diff options
Diffstat (limited to 'src')
418 files changed, 15008 insertions, 8289 deletions
diff --git a/src/.clang-format b/src/.clang-format index aae039dd77..ef7a0ef5c7 100644 --- a/src/.clang-format +++ b/src/.clang-format @@ -1,9 +1,9 @@ Language: Cpp AccessModifierOffset: -4 -AlignAfterOpenBracket: false +AlignAfterOpenBracket: true AlignEscapedNewlinesLeft: true AlignTrailingComments: true -AllowAllParametersOfDeclarationOnNextLine: false +AllowAllParametersOfDeclarationOnNextLine: true AllowShortBlocksOnASingleLine: false AllowShortCaseLabelsOnASingleLine: true AllowShortFunctionsOnASingleLine: All diff --git a/src/Makefile.am b/src/Makefile.am index 2b004691fd..cd3cc95707 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -110,9 +110,9 @@ BITCOIN_CORE_H = \ banman.h \ base58.h \ bech32.h \ - bloom.h \ blockencodings.h \ blockfilter.h \ + bloom.h \ chain.h \ chainparams.h \ chainparamsbase.h \ @@ -133,6 +133,7 @@ BITCOIN_CORE_H = \ core_io.h \ core_memusage.h \ cuckoocache.h \ + dbwrapper.h \ flatfile.h \ fs.h \ httprpc.h \ @@ -148,7 +149,6 @@ BITCOIN_CORE_H = \ interfaces/wallet.h \ key.h \ key_io.h \ - dbwrapper.h \ limitedmap.h \ logging.h \ logging/timer.h \ @@ -167,6 +167,7 @@ BITCOIN_CORE_H = \ node/context.h \ node/psbt.h \ node/transaction.h \ + node/ui_interface.h \ node/utxo_snapshot.h \ noui.h \ optional.h \ @@ -184,6 +185,7 @@ BITCOIN_CORE_H = \ reverse_iterator.h \ rpc/blockchain.h \ rpc/client.h \ + rpc/mining.h \ rpc/protocol.h \ rpc/rawtransaction_util.h \ rpc/register.h \ @@ -205,13 +207,12 @@ BITCOIN_CORE_H = \ support/events.h \ support/lockedpool.h \ sync.h \ - threadsafety.h \ threadinterrupt.h \ + threadsafety.h \ timedata.h \ torcontrol.h \ txdb.h \ txmempool.h \ - ui_interface.h \ undo.h \ util/asmap.h \ util/bip32.h \ @@ -220,8 +221,6 @@ BITCOIN_CORE_H = \ util/error.h \ util/fees.h \ util/golombrice.h \ - util/spanparsing.h \ - util/system.h \ util/macros.h \ util/memory.h \ util/message.h \ @@ -229,18 +228,23 @@ BITCOIN_CORE_H = \ util/rbf.h \ util/ref.h \ util/settings.h \ + util/spanparsing.h \ util/string.h \ + util/system.h \ util/threadnames.h \ util/time.h \ util/translation.h \ + util/ui_change_type.h \ util/url.h \ util/vector.h \ validation.h \ validationinterface.h \ versionbits.h \ versionbitsinfo.h \ - walletinitinterface.h \ + wallet/bdb.h \ wallet/coincontrol.h \ + wallet/coinselection.h \ + wallet/context.h \ wallet/crypter.h \ wallet/db.h \ wallet/feebumper.h \ @@ -254,7 +258,7 @@ BITCOIN_CORE_H = \ wallet/walletdb.h \ wallet/wallettool.h \ wallet/walletutil.h \ - wallet/coinselection.h \ + walletinitinterface.h \ warnings.h \ zmq/zmqabstractnotifier.h \ zmq/zmqconfig.h\ @@ -283,16 +287,16 @@ libbitcoin_server_a_SOURCES = \ blockfilter.cpp \ chain.cpp \ consensus/tx_verify.cpp \ + dbwrapper.cpp \ flatfile.cpp \ httprpc.cpp \ httpserver.cpp \ index/base.cpp \ index/blockfilterindex.cpp \ index/txindex.cpp \ + init.cpp \ interfaces/chain.cpp \ interfaces/node.cpp \ - init.cpp \ - dbwrapper.cpp \ miner.cpp \ net.cpp \ net_processing.cpp \ @@ -301,6 +305,7 @@ libbitcoin_server_a_SOURCES = \ node/context.cpp \ node/psbt.cpp \ node/transaction.cpp \ + node/ui_interface.cpp \ noui.cpp \ policy/fees.cpp \ policy/rbf.cpp \ @@ -319,7 +324,6 @@ libbitcoin_server_a_SOURCES = \ torcontrol.cpp \ txdb.cpp \ txmempool.cpp \ - ui_interface.cpp \ validation.cpp \ validationinterface.cpp \ versionbits.cpp \ @@ -349,7 +353,9 @@ libbitcoin_wallet_a_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) libbitcoin_wallet_a_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) libbitcoin_wallet_a_SOURCES = \ interfaces/wallet.cpp \ + wallet/bdb.cpp \ wallet/coincontrol.cpp \ + wallet/context.cpp \ wallet/crypter.cpp \ wallet/db.cpp \ wallet/feebumper.cpp \ diff --git a/src/Makefile.bench.include b/src/Makefile.bench.include index 766c0fca54..c224ca7bf6 100644 --- a/src/Makefile.bench.include +++ b/src/Makefile.bench.include @@ -29,9 +29,12 @@ bench_bench_bitcoin_SOURCES = \ bench/crypto_hash.cpp \ bench/ccoins_caching.cpp \ bench/gcs_filter.cpp \ + bench/hashpadding.cpp \ bench/merkle_root.cpp \ bench/mempool_eviction.cpp \ bench/mempool_stress.cpp \ + bench/nanobench.h \ + bench/nanobench.cpp \ bench/rpc_blockchain.cpp \ bench/rpc_mempool.cpp \ bench/util_time.cpp \ diff --git a/src/Makefile.qt.include b/src/Makefile.qt.include index 13bfea7646..848053e841 100644 --- a/src/Makefile.qt.include +++ b/src/Makefile.qt.include @@ -25,6 +25,7 @@ QT_FORMS_UI = \ qt/forms/openuridialog.ui \ qt/forms/optionsdialog.ui \ qt/forms/overviewpage.ui \ + qt/forms/psbtoperationsdialog.ui \ qt/forms/receivecoinsdialog.ui \ qt/forms/receiverequestdialog.ui \ qt/forms/debugwindow.ui \ @@ -61,6 +62,7 @@ QT_MOC_CPP = \ qt/moc_overviewpage.cpp \ qt/moc_peertablemodel.cpp \ qt/moc_paymentserver.cpp \ + qt/moc_psbtoperationsdialog.cpp \ qt/moc_qrimagewidget.cpp \ qt/moc_qvalidatedlineedit.cpp \ qt/moc_qvaluecombobox.cpp \ @@ -132,6 +134,7 @@ BITCOIN_QT_H = \ qt/paymentserver.h \ qt/peertablemodel.h \ qt/platformstyle.h \ + qt/psbtoperationsdialog.h \ qt/qrimagewidget.h \ qt/qvalidatedlineedit.h \ qt/qvaluecombobox.h \ @@ -245,6 +248,7 @@ BITCOIN_QT_WALLET_CPP = \ qt/openuridialog.cpp \ qt/overviewpage.cpp \ qt/paymentserver.cpp \ + qt/psbtoperationsdialog.cpp \ qt/qrimagewidget.cpp \ qt/receivecoinsdialog.cpp \ qt/receiverequestdialog.cpp \ @@ -272,9 +276,7 @@ if ENABLE_WALLET BITCOIN_QT_CPP += $(BITCOIN_QT_WALLET_CPP) endif # ENABLE_WALLET -RES_IMAGES = - -RES_MOVIES = $(wildcard $(srcdir)/qt/res/movies/spinner-*.png) +RES_ANIMATION = $(wildcard $(srcdir)/qt/res/animation/spinner-*.png) BITCOIN_RC = qt/res/bitcoin-qt-res.rc @@ -286,7 +288,7 @@ qt_libbitcoinqt_a_CXXFLAGS = $(AM_CXXFLAGS) $(QT_PIE_FLAGS) qt_libbitcoinqt_a_OBJCXXFLAGS = $(AM_OBJCXXFLAGS) $(QT_PIE_FLAGS) qt_libbitcoinqt_a_SOURCES = $(BITCOIN_QT_CPP) $(BITCOIN_QT_H) $(QT_FORMS_UI) \ - $(QT_QRC) $(QT_QRC_LOCALE) $(QT_TS) $(RES_ICONS) $(RES_IMAGES) $(RES_MOVIES) + $(QT_QRC) $(QT_QRC_LOCALE) $(QT_TS) $(RES_ICONS) $(RES_ANIMATION) if TARGET_DARWIN qt_libbitcoinqt_a_SOURCES += $(BITCOIN_MM) endif @@ -357,7 +359,7 @@ $(QT_QRC_LOCALE_CPP): $(QT_QRC_LOCALE) $(QT_QM) $(SED) -e '/^\*\*.*Created:/d' -e '/^\*\*.*by:/d' > $@ @rm $(@D)/temp_$(<F) -$(QT_QRC_CPP): $(QT_QRC) $(QT_FORMS_H) $(RES_ICONS) $(RES_IMAGES) $(RES_MOVIES) +$(QT_QRC_CPP): $(QT_QRC) $(QT_FORMS_H) $(RES_ICONS) $(RES_ANIMATION) @test -f $(RCC) $(AM_V_GEN) QT_SELECT=$(QT_SELECT) $(RCC) -name bitcoin $< | \ $(SED) -e '/^\*\*.*Created:/d' -e '/^\*\*.*by:/d' > $@ diff --git a/src/Makefile.test.include b/src/Makefile.test.include index 7909cb4a0f..c3e46c0def 100644 --- a/src/Makefile.test.include +++ b/src/Makefile.test.include @@ -10,7 +10,9 @@ FUZZ_TARGETS = \ test/fuzz/addrman_deserialize \ test/fuzz/asmap \ test/fuzz/asmap_direct \ + test/fuzz/autofile \ test/fuzz/banentry_deserialize \ + test/fuzz/banman \ test/fuzz/base_encode_decode \ test/fuzz/bech32 \ test/fuzz/block \ @@ -28,10 +30,19 @@ FUZZ_TARGETS = \ test/fuzz/blockundo_deserialize \ test/fuzz/bloom_filter \ test/fuzz/bloomfilter_deserialize \ + test/fuzz/buffered_file \ test/fuzz/chain \ test/fuzz/checkqueue \ test/fuzz/coins_deserialize \ test/fuzz/coins_view \ + test/fuzz/crypto \ + test/fuzz/crypto_aes256 \ + test/fuzz/crypto_aes256cbc \ + test/fuzz/crypto_chacha20 \ + test/fuzz/crypto_chacha20_poly1305_aead \ + test/fuzz/crypto_common \ + test/fuzz/crypto_hkdf_hmac_sha256_l32 \ + test/fuzz/crypto_poly1305 \ test/fuzz/cuckoocache \ test/fuzz/decode_tx \ test/fuzz/descriptor_parse \ @@ -52,6 +63,7 @@ FUZZ_TARGETS = \ test/fuzz/key_io \ test/fuzz/key_origin_info_deserialize \ test/fuzz/kitchen_sink \ + test/fuzz/load_external_block_file \ test/fuzz/locale \ test/fuzz/merkle_block_deserialize \ test/fuzz/merkleblock \ @@ -71,6 +83,7 @@ FUZZ_TARGETS = \ test/fuzz/partial_merkle_tree_deserialize \ test/fuzz/partially_signed_transaction_deserialize \ test/fuzz/policy_estimator \ + test/fuzz/policy_estimator_io \ test/fuzz/pow \ test/fuzz/prefilled_transaction_deserialize \ test/fuzz/prevector \ @@ -230,6 +243,7 @@ BITCOIN_TESTS =\ test/net_tests.cpp \ test/netbase_tests.cpp \ test/pmt_tests.cpp \ + test/policy_fee_tests.cpp \ test/policyestimator_tests.cpp \ test/pow_tests.cpp \ test/prevector_tests.cpp \ @@ -251,6 +265,7 @@ BITCOIN_TESTS =\ test/skiplist_tests.cpp \ test/streams_tests.cpp \ test/sync_tests.cpp \ + test/system_tests.cpp \ test/util_threadnames_tests.cpp \ test/timedata_tests.cpp \ test/torcontrol_tests.cpp \ @@ -261,6 +276,7 @@ BITCOIN_TESTS =\ test/uint256_tests.cpp \ test/util_tests.cpp \ test/validation_block_tests.cpp \ + test/validation_chainstate_tests.cpp \ test/validation_chainstatemanager_tests.cpp \ test/validation_flush_tests.cpp \ test/validationinterface_tests.cpp \ @@ -346,12 +362,24 @@ test_fuzz_asmap_direct_LDADD = $(FUZZ_SUITE_LD_COMMON) test_fuzz_asmap_direct_LDFLAGS = $(RELDFLAGS) $(AM_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) test_fuzz_asmap_direct_SOURCES = test/fuzz/asmap_direct.cpp +test_fuzz_autofile_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) +test_fuzz_autofile_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) +test_fuzz_autofile_LDADD = $(FUZZ_SUITE_LD_COMMON) +test_fuzz_autofile_LDFLAGS = $(RELDFLAGS) $(AM_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) +test_fuzz_autofile_SOURCES = test/fuzz/autofile.cpp + test_fuzz_banentry_deserialize_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) -DBANENTRY_DESERIALIZE=1 test_fuzz_banentry_deserialize_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) test_fuzz_banentry_deserialize_LDADD = $(FUZZ_SUITE_LD_COMMON) test_fuzz_banentry_deserialize_LDFLAGS = $(RELDFLAGS) $(AM_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) test_fuzz_banentry_deserialize_SOURCES = test/fuzz/deserialize.cpp +test_fuzz_banman_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) +test_fuzz_banman_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) +test_fuzz_banman_LDADD = $(FUZZ_SUITE_LD_COMMON) +test_fuzz_banman_LDFLAGS = $(RELDFLAGS) $(AM_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) +test_fuzz_banman_SOURCES = test/fuzz/banman.cpp + test_fuzz_base_encode_decode_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) test_fuzz_base_encode_decode_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) test_fuzz_base_encode_decode_LDADD = $(FUZZ_SUITE_LD_COMMON) @@ -454,6 +482,12 @@ test_fuzz_bloomfilter_deserialize_LDADD = $(FUZZ_SUITE_LD_COMMON) test_fuzz_bloomfilter_deserialize_LDFLAGS = $(RELDFLAGS) $(AM_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) test_fuzz_bloomfilter_deserialize_SOURCES = test/fuzz/deserialize.cpp +test_fuzz_buffered_file_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) +test_fuzz_buffered_file_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) +test_fuzz_buffered_file_LDADD = $(FUZZ_SUITE_LD_COMMON) +test_fuzz_buffered_file_LDFLAGS = $(RELDFLAGS) $(AM_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) +test_fuzz_buffered_file_SOURCES = test/fuzz/buffered_file.cpp + test_fuzz_chain_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) test_fuzz_chain_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) test_fuzz_chain_LDADD = $(FUZZ_SUITE_LD_COMMON) @@ -478,6 +512,54 @@ test_fuzz_coins_view_LDADD = $(FUZZ_SUITE_LD_COMMON) test_fuzz_coins_view_LDFLAGS = $(RELDFLAGS) $(AM_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) test_fuzz_coins_view_SOURCES = test/fuzz/coins_view.cpp +test_fuzz_crypto_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) +test_fuzz_crypto_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) +test_fuzz_crypto_LDADD = $(FUZZ_SUITE_LD_COMMON) +test_fuzz_crypto_LDFLAGS = $(RELDFLAGS) $(AM_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) +test_fuzz_crypto_SOURCES = test/fuzz/crypto.cpp + +test_fuzz_crypto_aes256_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) +test_fuzz_crypto_aes256_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) +test_fuzz_crypto_aes256_LDADD = $(FUZZ_SUITE_LD_COMMON) +test_fuzz_crypto_aes256_LDFLAGS = $(RELDFLAGS) $(AM_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) +test_fuzz_crypto_aes256_SOURCES = test/fuzz/crypto_aes256.cpp + +test_fuzz_crypto_aes256cbc_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) +test_fuzz_crypto_aes256cbc_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) +test_fuzz_crypto_aes256cbc_LDADD = $(FUZZ_SUITE_LD_COMMON) +test_fuzz_crypto_aes256cbc_LDFLAGS = $(RELDFLAGS) $(AM_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) +test_fuzz_crypto_aes256cbc_SOURCES = test/fuzz/crypto_aes256cbc.cpp + +test_fuzz_crypto_chacha20_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) +test_fuzz_crypto_chacha20_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) +test_fuzz_crypto_chacha20_LDADD = $(FUZZ_SUITE_LD_COMMON) +test_fuzz_crypto_chacha20_LDFLAGS = $(RELDFLAGS) $(AM_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) +test_fuzz_crypto_chacha20_SOURCES = test/fuzz/crypto_chacha20.cpp + +test_fuzz_crypto_chacha20_poly1305_aead_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) +test_fuzz_crypto_chacha20_poly1305_aead_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) +test_fuzz_crypto_chacha20_poly1305_aead_LDADD = $(FUZZ_SUITE_LD_COMMON) +test_fuzz_crypto_chacha20_poly1305_aead_LDFLAGS = $(RELDFLAGS) $(AM_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) +test_fuzz_crypto_chacha20_poly1305_aead_SOURCES = test/fuzz/crypto_chacha20_poly1305_aead.cpp + +test_fuzz_crypto_common_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) +test_fuzz_crypto_common_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) +test_fuzz_crypto_common_LDADD = $(FUZZ_SUITE_LD_COMMON) +test_fuzz_crypto_common_LDFLAGS = $(RELDFLAGS) $(AM_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) +test_fuzz_crypto_common_SOURCES = test/fuzz/crypto_common.cpp + +test_fuzz_crypto_hkdf_hmac_sha256_l32_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) +test_fuzz_crypto_hkdf_hmac_sha256_l32_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) +test_fuzz_crypto_hkdf_hmac_sha256_l32_LDADD = $(FUZZ_SUITE_LD_COMMON) +test_fuzz_crypto_hkdf_hmac_sha256_l32_LDFLAGS = $(RELDFLAGS) $(AM_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) +test_fuzz_crypto_hkdf_hmac_sha256_l32_SOURCES = test/fuzz/crypto_hkdf_hmac_sha256_l32.cpp + +test_fuzz_crypto_poly1305_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) +test_fuzz_crypto_poly1305_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) +test_fuzz_crypto_poly1305_LDADD = $(FUZZ_SUITE_LD_COMMON) +test_fuzz_crypto_poly1305_LDFLAGS = $(RELDFLAGS) $(AM_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) +test_fuzz_crypto_poly1305_SOURCES = test/fuzz/crypto_poly1305.cpp + test_fuzz_cuckoocache_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) test_fuzz_cuckoocache_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) test_fuzz_cuckoocache_LDADD = $(FUZZ_SUITE_LD_COMMON) @@ -598,6 +680,12 @@ test_fuzz_kitchen_sink_LDADD = $(FUZZ_SUITE_LD_COMMON) test_fuzz_kitchen_sink_LDFLAGS = $(RELDFLAGS) $(AM_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) test_fuzz_kitchen_sink_SOURCES = test/fuzz/kitchen_sink.cpp +test_fuzz_load_external_block_file_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) +test_fuzz_load_external_block_file_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) +test_fuzz_load_external_block_file_LDADD = $(FUZZ_SUITE_LD_COMMON) +test_fuzz_load_external_block_file_LDFLAGS = $(RELDFLAGS) $(AM_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) +test_fuzz_load_external_block_file_SOURCES = test/fuzz/load_external_block_file.cpp + test_fuzz_locale_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) test_fuzz_locale_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) test_fuzz_locale_LDADD = $(FUZZ_SUITE_LD_COMMON) @@ -718,6 +806,12 @@ test_fuzz_policy_estimator_LDADD = $(FUZZ_SUITE_LD_COMMON) test_fuzz_policy_estimator_LDFLAGS = $(RELDFLAGS) $(AM_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) test_fuzz_policy_estimator_SOURCES = test/fuzz/policy_estimator.cpp +test_fuzz_policy_estimator_io_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) +test_fuzz_policy_estimator_io_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) +test_fuzz_policy_estimator_io_LDADD = $(FUZZ_SUITE_LD_COMMON) +test_fuzz_policy_estimator_io_LDFLAGS = $(RELDFLAGS) $(AM_LDFLAGS) $(LIBTOOL_APP_LDFLAGS) +test_fuzz_policy_estimator_io_SOURCES = test/fuzz/policy_estimator_io.cpp + test_fuzz_pow_CPPFLAGS = $(AM_CPPFLAGS) $(BITCOIN_INCLUDES) test_fuzz_pow_CXXFLAGS = $(AM_CXXFLAGS) $(PIE_FLAGS) test_fuzz_pow_LDADD = $(FUZZ_SUITE_LD_COMMON) @@ -1114,7 +1208,7 @@ nodist_test_test_bitcoin_SOURCES = $(GENERATED_TEST_FILES) $(BITCOIN_TESTS): $(GENERATED_TEST_FILES) -CLEAN_BITCOIN_TEST = test/*.gcda test/*.gcno $(GENERATED_TEST_FILES) $(BITCOIN_TESTS:=.log) +CLEAN_BITCOIN_TEST = test/*.gcda test/*.gcno test/fuzz/*.gcda test/fuzz/*.gcno $(GENERATED_TEST_FILES) $(BITCOIN_TESTS:=.log) CLEANFILES += $(CLEAN_BITCOIN_TEST) @@ -1144,8 +1238,8 @@ endif if TARGET_WINDOWS else if ENABLE_BENCH - @echo "Running bench/bench_bitcoin -evals=1 -scaling=0..." - $(BENCH_BINARY) -evals=1 -scaling=0 > /dev/null + @echo "Running bench/bench_bitcoin ..." + $(BENCH_BINARY) > /dev/null endif endif $(AM_V_at)$(MAKE) $(AM_MAKEFLAGS) -C secp256k1 check diff --git a/src/addrdb.cpp b/src/addrdb.cpp index 835c5d6c65..f3e8a19de2 100644 --- a/src/addrdb.cpp +++ b/src/addrdb.cpp @@ -8,6 +8,7 @@ #include <addrman.h> #include <chainparams.h> #include <clientversion.h> +#include <cstdint> #include <hash.h> #include <random.h> #include <streams.h> @@ -36,7 +37,7 @@ template <typename Data> bool SerializeFileDB(const std::string& prefix, const fs::path& path, const Data& data) { // Generate random temporary filename - unsigned short randv = 0; + uint16_t randv = 0; GetRandBytes((unsigned char*)&randv, sizeof(randv)); std::string tmpfn = strprintf("%s.%04x", prefix, randv); diff --git a/src/addrdb.h b/src/addrdb.h index c6d4307d69..8410c3776c 100644 --- a/src/addrdb.h +++ b/src/addrdb.h @@ -17,13 +17,6 @@ class CSubNet; class CAddrMan; class CDataStream; -typedef enum BanReason -{ - BanReasonUnknown = 0, - BanReasonNodeMisbehaving = 1, - BanReasonManuallyAdded = 2 -} BanReason; - class CBanEntry { public: @@ -31,7 +24,6 @@ public: int nVersion; int64_t nCreateTime; int64_t nBanUntil; - uint8_t banReason; CBanEntry() { @@ -44,31 +36,17 @@ public: nCreateTime = nCreateTimeIn; } - explicit CBanEntry(int64_t n_create_time_in, BanReason ban_reason_in) : CBanEntry(n_create_time_in) + SERIALIZE_METHODS(CBanEntry, obj) { - banReason = ban_reason_in; + uint8_t ban_reason = 2; //! For backward compatibility + READWRITE(obj.nVersion, obj.nCreateTime, obj.nBanUntil, ban_reason); } - SERIALIZE_METHODS(CBanEntry, obj) { READWRITE(obj.nVersion, obj.nCreateTime, obj.nBanUntil, obj.banReason); } - void SetNull() { nVersion = CBanEntry::CURRENT_VERSION; nCreateTime = 0; nBanUntil = 0; - banReason = BanReasonUnknown; - } - - std::string banReasonToString() const - { - switch (banReason) { - case BanReasonNodeMisbehaving: - return "node misbehaving"; - case BanReasonManuallyAdded: - return "manually added"; - default: - return "unknown"; - } } }; diff --git a/src/addrman.cpp b/src/addrman.cpp index 7aba340d9d..7636c6bad2 100644 --- a/src/addrman.cpp +++ b/src/addrman.cpp @@ -479,11 +479,15 @@ int CAddrMan::Check_() } #endif -void CAddrMan::GetAddr_(std::vector<CAddress>& vAddr) +void CAddrMan::GetAddr_(std::vector<CAddress>& vAddr, size_t max_addresses, size_t max_pct) { - unsigned int nNodes = ADDRMAN_GETADDR_MAX_PCT * vRandom.size() / 100; - if (nNodes > ADDRMAN_GETADDR_MAX) - nNodes = ADDRMAN_GETADDR_MAX; + size_t nNodes = vRandom.size(); + if (max_pct != 0) { + nNodes = max_pct * nNodes / 100; + } + if (max_addresses != 0) { + nNodes = std::min(nNodes, max_addresses); + } // gather a list of random nodes, skipping those of low quality for (unsigned int n = 0; n < vRandom.size(); n++) { diff --git a/src/addrman.h b/src/addrman.h index 8e82020df0..ca045b91cd 100644 --- a/src/addrman.h +++ b/src/addrman.h @@ -153,12 +153,6 @@ public: //! how recent a successful connection should be before we allow an address to be evicted from tried #define ADDRMAN_REPLACEMENT_HOURS 4 -//! the maximum percentage of nodes to return in a getaddr call -#define ADDRMAN_GETADDR_MAX_PCT 23 - -//! the maximum number of nodes to return in a getaddr call -#define ADDRMAN_GETADDR_MAX 2500 - //! Convenience #define ADDRMAN_TRIED_BUCKET_COUNT (1 << ADDRMAN_TRIED_BUCKET_COUNT_LOG2) #define ADDRMAN_NEW_BUCKET_COUNT (1 << ADDRMAN_NEW_BUCKET_COUNT_LOG2) @@ -261,7 +255,7 @@ protected: #endif //! Select several addresses at once. - void GetAddr_(std::vector<CAddress> &vAddr) EXCLUSIVE_LOCKS_REQUIRED(cs); + void GetAddr_(std::vector<CAddress> &vAddr, size_t max_addresses, size_t max_pct) EXCLUSIVE_LOCKS_REQUIRED(cs); //! Mark an entry as currently-connected-to. void Connected_(const CService &addr, int64_t nTime) EXCLUSIVE_LOCKS_REQUIRED(cs); @@ -638,13 +632,13 @@ public: } //! Return a bunch of addresses, selected at random. - std::vector<CAddress> GetAddr() + std::vector<CAddress> GetAddr(size_t max_addresses, size_t max_pct) { Check(); std::vector<CAddress> vAddr; { LOCK(cs); - GetAddr_(vAddr); + GetAddr_(vAddr, max_addresses, max_pct); } Check(); return vAddr; diff --git a/src/banman.cpp b/src/banman.cpp index 9cc584f0e4..8752185a60 100644 --- a/src/banman.cpp +++ b/src/banman.cpp @@ -6,7 +6,7 @@ #include <banman.h> #include <netaddress.h> -#include <ui_interface.h> +#include <node/ui_interface.h> #include <util/system.h> #include <util/time.h> #include <util/translation.h> @@ -26,7 +26,7 @@ BanMan::BanMan(fs::path ban_file, CClientUIInterface* client_interface, int64_t SweepBanned(); // sweep out unused entries LogPrint(BCLog::NET, "Loaded %d banned node ips/subnets from banlist.dat %dms\n", - banmap.size(), GetTimeMillis() - n_start); + m_banned.size(), GetTimeMillis() - n_start); } else { LogPrintf("Invalid or missing banlist.dat; recreating\n"); SetBannedSetDirty(true); // force write @@ -68,28 +68,13 @@ void BanMan::ClearBanned() if (m_client_interface) m_client_interface->BannedListChanged(); } -int BanMan::IsBannedLevel(CNetAddr net_addr) +bool BanMan::IsDiscouraged(const CNetAddr& net_addr) { - // Returns the most severe level of banning that applies to this address. - // 0 - Not banned - // 1 - Automatic misbehavior ban - // 2 - Any other ban - int level = 0; - auto current_time = GetTime(); LOCK(m_cs_banned); - for (const auto& it : m_banned) { - CSubNet sub_net = it.first; - CBanEntry ban_entry = it.second; - - if (current_time < ban_entry.nBanUntil && sub_net.Match(net_addr)) { - if (ban_entry.banReason != BanReasonNodeMisbehaving) return 2; - level = 1; - } - } - return level; + return m_discouraged.contains(net_addr.GetAddrBytes()); } -bool BanMan::IsBanned(CNetAddr net_addr) +bool BanMan::IsBanned(const CNetAddr& net_addr) { auto current_time = GetTime(); LOCK(m_cs_banned); @@ -104,7 +89,7 @@ bool BanMan::IsBanned(CNetAddr net_addr) return false; } -bool BanMan::IsBanned(CSubNet sub_net) +bool BanMan::IsBanned(const CSubNet& sub_net) { auto current_time = GetTime(); LOCK(m_cs_banned); @@ -118,15 +103,21 @@ bool BanMan::IsBanned(CSubNet sub_net) return false; } -void BanMan::Ban(const CNetAddr& net_addr, const BanReason& ban_reason, int64_t ban_time_offset, bool since_unix_epoch) +void BanMan::Ban(const CNetAddr& net_addr, int64_t ban_time_offset, bool since_unix_epoch) { CSubNet sub_net(net_addr); - Ban(sub_net, ban_reason, ban_time_offset, since_unix_epoch); + Ban(sub_net, ban_time_offset, since_unix_epoch); +} + +void BanMan::Discourage(const CNetAddr& net_addr) +{ + LOCK(m_cs_banned); + m_discouraged.insert(net_addr.GetAddrBytes()); } -void BanMan::Ban(const CSubNet& sub_net, const BanReason& ban_reason, int64_t ban_time_offset, bool since_unix_epoch) +void BanMan::Ban(const CSubNet& sub_net, int64_t ban_time_offset, bool since_unix_epoch) { - CBanEntry ban_entry(GetTime(), ban_reason); + CBanEntry ban_entry(GetTime()); int64_t normalized_ban_time_offset = ban_time_offset; bool normalized_since_unix_epoch = since_unix_epoch; @@ -146,8 +137,8 @@ void BanMan::Ban(const CSubNet& sub_net, const BanReason& ban_reason, int64_t ba } if (m_client_interface) m_client_interface->BannedListChanged(); - //store banlist to disk immediately if user requested ban - if (ban_reason == BanReasonManuallyAdded) DumpBanlist(); + //store banlist to disk immediately + DumpBanlist(); } bool BanMan::Unban(const CNetAddr& net_addr) diff --git a/src/banman.h b/src/banman.h index 6bea2e75e9..f6bfbd1e49 100644 --- a/src/banman.h +++ b/src/banman.h @@ -6,6 +6,7 @@ #define BITCOIN_BANMAN_H #include <addrdb.h> +#include <bloom.h> #include <fs.h> #include <net_types.h> // For banmap_t #include <sync.h> @@ -23,32 +24,55 @@ class CClientUIInterface; class CNetAddr; class CSubNet; -// Denial-of-service detection/prevention -// The idea is to detect peers that are behaving -// badly and disconnect/ban them, but do it in a -// one-coding-mistake-won't-shatter-the-entire-network -// way. -// IMPORTANT: There should be nothing I can give a -// node that it will forward on that will make that -// node's peers drop it. If there is, an attacker -// can isolate a node and/or try to split the network. -// Dropping a node for sending stuff that is invalid -// now but might be valid in a later version is also -// dangerous, because it can cause a network split -// between nodes running old code and nodes running -// new code. +// Banman manages two related but distinct concepts: +// +// 1. Banning. This is configured manually by the user, through the setban RPC. +// If an address or subnet is banned, we never accept incoming connections from +// it and never create outgoing connections to it. We won't gossip its address +// to other peers in addr messages. Banned addresses and subnets are stored to +// banlist.dat on shutdown and reloaded on startup. Banning can be used to +// prevent connections with spy nodes or other griefers. +// +// 2. Discouragement. If a peer misbehaves enough (see Misbehaving() in +// net_processing.cpp), we'll mark that address as discouraged. We still allow +// incoming connections from them, but they're preferred for eviction when +// we receive new incoming connections. We never make outgoing connections to +// them, and do not gossip their address to other peers. This is implemented as +// a bloom filter. We can (probabilistically) test for membership, but can't +// list all discouraged addresses or unmark them as discouraged. Discouragement +// can prevent our limited connection slots being used up by incompatible +// or broken peers. +// +// Neither banning nor discouragement are protections against denial-of-service +// attacks, since if an attacker has a way to waste our resources and we +// disconnect from them and ban that address, it's trivial for them to +// reconnect from another IP address. +// +// Attempting to automatically disconnect or ban any class of peer carries the +// risk of splitting the network. For example, if we banned/disconnected for a +// transaction that fails a policy check and a future version changes the +// policy check so the transaction is accepted, then that transaction could +// cause the network to split between old nodes and new nodes. class BanMan { public: ~BanMan(); BanMan(fs::path ban_file, CClientUIInterface* client_interface, int64_t default_ban_time); - void Ban(const CNetAddr& net_addr, const BanReason& ban_reason, int64_t ban_time_offset = 0, bool since_unix_epoch = false); - void Ban(const CSubNet& sub_net, const BanReason& ban_reason, int64_t ban_time_offset = 0, bool since_unix_epoch = false); + void Ban(const CNetAddr& net_addr, int64_t ban_time_offset = 0, bool since_unix_epoch = false); + void Ban(const CSubNet& sub_net, int64_t ban_time_offset = 0, bool since_unix_epoch = false); + void Discourage(const CNetAddr& net_addr); void ClearBanned(); - int IsBannedLevel(CNetAddr net_addr); - bool IsBanned(CNetAddr net_addr); - bool IsBanned(CSubNet sub_net); + + //! Return whether net_addr is banned + bool IsBanned(const CNetAddr& net_addr); + + //! Return whether sub_net is exactly banned + bool IsBanned(const CSubNet& sub_net); + + //! Return whether net_addr is discouraged. + bool IsDiscouraged(const CNetAddr& net_addr); + bool Unban(const CNetAddr& net_addr); bool Unban(const CSubNet& sub_net); void GetBanned(banmap_t& banmap); @@ -68,6 +92,7 @@ private: CClientUIInterface* m_client_interface = nullptr; CBanDB m_ban_db; const int64_t m_default_ban_time; + CRollingBloomFilter m_discouraged GUARDED_BY(m_cs_banned) {50000, 0.000001}; }; #endif diff --git a/src/base58.cpp b/src/base58.cpp index 6a9e21ffc2..9b2946e7a9 100644 --- a/src/base58.cpp +++ b/src/base58.cpp @@ -141,7 +141,7 @@ std::string EncodeBase58Check(const std::vector<unsigned char>& vchIn) { // add 4-byte hash check to the end std::vector<unsigned char> vch(vchIn); - uint256 hash = Hash(vch.begin(), vch.end()); + uint256 hash = Hash(vch); vch.insert(vch.end(), (unsigned char*)&hash, (unsigned char*)&hash + 4); return EncodeBase58(vch); } @@ -154,7 +154,7 @@ bool DecodeBase58Check(const char* psz, std::vector<unsigned char>& vchRet, int return false; } // re-calculate the checksum, ensure it matches the included 4-byte checksum - uint256 hash = Hash(vchRet.begin(), vchRet.end() - 4); + uint256 hash = Hash(MakeSpan(vchRet).first(vchRet.size() - 4)); if (memcmp(&hash, &vchRet[vchRet.size() - 4], 4) != 0) { vchRet.clear(); return false; diff --git a/src/bench/addrman.cpp b/src/bench/addrman.cpp index cc260df2b8..ebdad5a4b8 100644 --- a/src/bench/addrman.cpp +++ b/src/bench/addrman.cpp @@ -67,52 +67,52 @@ static void FillAddrMan(CAddrMan& addrman) /* Benchmarks */ -static void AddrManAdd(benchmark::State& state) +static void AddrManAdd(benchmark::Bench& bench) { CreateAddresses(); CAddrMan addrman; - while (state.KeepRunning()) { + bench.run([&] { AddAddressesToAddrMan(addrman); addrman.Clear(); - } + }); } -static void AddrManSelect(benchmark::State& state) +static void AddrManSelect(benchmark::Bench& bench) { CAddrMan addrman; FillAddrMan(addrman); - while (state.KeepRunning()) { + bench.run([&] { const auto& address = addrman.Select(); assert(address.GetPort() > 0); - } + }); } -static void AddrManGetAddr(benchmark::State& state) +static void AddrManGetAddr(benchmark::Bench& bench) { CAddrMan addrman; FillAddrMan(addrman); - while (state.KeepRunning()) { - const auto& addresses = addrman.GetAddr(); + bench.run([&] { + const auto& addresses = addrman.GetAddr(2500, 23); assert(addresses.size() > 0); - } + }); } -static void AddrManGood(benchmark::State& state) +static void AddrManGood(benchmark::Bench& bench) { /* Create many CAddrMan objects - one to be modified at each loop iteration. * This is necessary because the CAddrMan::Good() method modifies the * object, affecting the timing of subsequent calls to the same method and * we want to do the same amount of work in every loop iteration. */ - const uint64_t numLoops = state.m_num_iters * state.m_num_evals; + bench.epochs(5).epochIterations(1); - std::vector<CAddrMan> addrmans(numLoops); + std::vector<CAddrMan> addrmans(bench.epochs() * bench.epochIterations()); for (auto& addrman : addrmans) { FillAddrMan(addrman); } @@ -128,13 +128,13 @@ static void AddrManGood(benchmark::State& state) }; uint64_t i = 0; - while (state.KeepRunning()) { + bench.run([&] { markSomeAsGood(addrmans.at(i)); ++i; - } + }); } -BENCHMARK(AddrManAdd, 5); -BENCHMARK(AddrManSelect, 1000000); -BENCHMARK(AddrManGetAddr, 500); -BENCHMARK(AddrManGood, 2); +BENCHMARK(AddrManAdd); +BENCHMARK(AddrManSelect); +BENCHMARK(AddrManGetAddr); +BENCHMARK(AddrManGood); diff --git a/src/bench/base58.cpp b/src/bench/base58.cpp index 0690483d50..00544cba31 100644 --- a/src/bench/base58.cpp +++ b/src/bench/base58.cpp @@ -10,7 +10,7 @@ #include <vector> -static void Base58Encode(benchmark::State& state) +static void Base58Encode(benchmark::Bench& bench) { static const std::array<unsigned char, 32> buff = { { @@ -19,13 +19,13 @@ static void Base58Encode(benchmark::State& state) 200, 24 } }; - while (state.KeepRunning()) { + bench.batch(buff.size()).unit("byte").run([&] { EncodeBase58(buff.data(), buff.data() + buff.size()); - } + }); } -static void Base58CheckEncode(benchmark::State& state) +static void Base58CheckEncode(benchmark::Bench& bench) { static const std::array<unsigned char, 32> buff = { { @@ -36,22 +36,22 @@ static void Base58CheckEncode(benchmark::State& state) }; std::vector<unsigned char> vch; vch.assign(buff.begin(), buff.end()); - while (state.KeepRunning()) { + bench.batch(buff.size()).unit("byte").run([&] { EncodeBase58Check(vch); - } + }); } -static void Base58Decode(benchmark::State& state) +static void Base58Decode(benchmark::Bench& bench) { const char* addr = "17VZNX1SN5NtKa8UQFxwQbFeFc3iqRYhem"; std::vector<unsigned char> vch; - while (state.KeepRunning()) { + bench.batch(strlen(addr)).unit("byte").run([&] { (void) DecodeBase58(addr, vch, 64); - } + }); } -BENCHMARK(Base58Encode, 470 * 1000); -BENCHMARK(Base58CheckEncode, 320 * 1000); -BENCHMARK(Base58Decode, 800 * 1000); +BENCHMARK(Base58Encode); +BENCHMARK(Base58CheckEncode); +BENCHMARK(Base58Decode); diff --git a/src/bench/bech32.cpp b/src/bench/bech32.cpp index 2107840a3a..c74d8d51b3 100644 --- a/src/bench/bech32.cpp +++ b/src/bench/bech32.cpp @@ -3,6 +3,7 @@ // file COPYING or http://www.opensource.org/licenses/mit-license.php. #include <bench/bench.h> +#include <bench/nanobench.h> #include <bech32.h> #include <util/strencodings.h> @@ -11,26 +12,26 @@ #include <vector> -static void Bech32Encode(benchmark::State& state) +static void Bech32Encode(benchmark::Bench& bench) { std::vector<uint8_t> v = ParseHex("c97f5a67ec381b760aeaf67573bc164845ff39a3bb26a1cee401ac67243b48db"); std::vector<unsigned char> tmp = {0}; tmp.reserve(1 + 32 * 8 / 5); ConvertBits<8, 5, true>([&](unsigned char c) { tmp.push_back(c); }, v.begin(), v.end()); - while (state.KeepRunning()) { + bench.batch(v.size()).unit("byte").run([&] { bech32::Encode("bc", tmp); - } + }); } -static void Bech32Decode(benchmark::State& state) +static void Bech32Decode(benchmark::Bench& bench) { std::string addr = "bc1qkallence7tjawwvy0dwt4twc62qjgaw8f4vlhyd006d99f09"; - while (state.KeepRunning()) { + bench.batch(addr.size()).unit("byte").run([&] { bech32::Decode(addr); - } + }); } -BENCHMARK(Bech32Encode, 800 * 1000); -BENCHMARK(Bech32Decode, 800 * 1000); +BENCHMARK(Bech32Encode); +BENCHMARK(Bech32Decode); diff --git a/src/bench/bench.cpp b/src/bench/bench.cpp index 7b93ef688d..01466d0b6f 100644 --- a/src/bench/bench.cpp +++ b/src/bench/bench.cpp @@ -8,141 +8,73 @@ #include <test/util/setup_common.h> #include <validation.h> -#include <algorithm> -#include <assert.h> -#include <iomanip> -#include <iostream> -#include <numeric> #include <regex> const std::function<void(const std::string&)> G_TEST_LOG_FUN{}; -void benchmark::ConsolePrinter::header() -{ - std::cout << "# Benchmark, evals, iterations, total, min, max, median" << std::endl; -} +namespace { -void benchmark::ConsolePrinter::result(const State& state) +void GenerateTemplateResults(const std::vector<ankerl::nanobench::Result>& benchmarkResults, const std::string& filename, const char* tpl) { - auto results = state.m_elapsed_results; - std::sort(results.begin(), results.end()); - - double total = state.m_num_iters * std::accumulate(results.begin(), results.end(), 0.0); - - double front = 0; - double back = 0; - double median = 0; - - if (!results.empty()) { - front = results.front(); - back = results.back(); - - size_t mid = results.size() / 2; - median = results[mid]; - if (0 == results.size() % 2) { - median = (results[mid] + results[mid + 1]) / 2; - } + if (benchmarkResults.empty() || filename.empty()) { + // nothing to write, bail out + return; } - - std::cout << std::setprecision(6); - std::cout << state.m_name << ", " << state.m_num_evals << ", " << state.m_num_iters << ", " << total << ", " << front << ", " << back << ", " << median << std::endl; -} - -void benchmark::ConsolePrinter::footer() {} -benchmark::PlotlyPrinter::PlotlyPrinter(std::string plotly_url, int64_t width, int64_t height) - : m_plotly_url(plotly_url), m_width(width), m_height(height) -{ -} - -void benchmark::PlotlyPrinter::header() -{ - std::cout << "<html><head>" - << "<script src=\"" << m_plotly_url << "\"></script>" - << "</head><body><div id=\"myDiv\" style=\"width:" << m_width << "px; height:" << m_height << "px\"></div>" - << "<script> var data = [" - << std::endl; -} - -void benchmark::PlotlyPrinter::result(const State& state) -{ - std::cout << "{ " << std::endl - << " name: '" << state.m_name << "', " << std::endl - << " y: ["; - - const char* prefix = ""; - for (const auto& e : state.m_elapsed_results) { - std::cout << prefix << std::setprecision(6) << e; - prefix = ", "; + std::ofstream fout(filename); + if (fout.is_open()) { + ankerl::nanobench::render(tpl, benchmarkResults, fout); + } else { + std::cout << "Could write to file '" << filename << "'" << std::endl; } - std::cout << "]," << std::endl - << " boxpoints: 'all', jitter: 0.3, pointpos: 0, type: 'box'," - << std::endl - << "}," << std::endl; -} -void benchmark::PlotlyPrinter::footer() -{ - std::cout << "]; var layout = { showlegend: false, yaxis: { rangemode: 'tozero', autorange: true } };" - << "Plotly.newPlot('myDiv', data, layout);" - << "</script></body></html>"; + std::cout << "Created '" << filename << "'" << std::endl; } +} // namespace benchmark::BenchRunner::BenchmarkMap& benchmark::BenchRunner::benchmarks() { - static std::map<std::string, Bench> benchmarks_map; + static std::map<std::string, BenchFunction> benchmarks_map; return benchmarks_map; } -benchmark::BenchRunner::BenchRunner(std::string name, benchmark::BenchFunction func, uint64_t num_iters_for_one_second) +benchmark::BenchRunner::BenchRunner(std::string name, benchmark::BenchFunction func) { - benchmarks().insert(std::make_pair(name, Bench{func, num_iters_for_one_second})); + benchmarks().insert(std::make_pair(name, func)); } -void benchmark::BenchRunner::RunAll(Printer& printer, uint64_t num_evals, double scaling, const std::string& filter, bool is_list_only) +void benchmark::BenchRunner::RunAll(const Args& args) { - if (!std::ratio_less_equal<benchmark::clock::period, std::micro>::value) { - std::cerr << "WARNING: Clock precision is worse than microsecond - benchmarks may be less accurate!\n"; - } -#ifdef DEBUG - std::cerr << "WARNING: This is a debug build - may result in slower benchmarks.\n"; -#endif - - std::regex reFilter(filter); + std::regex reFilter(args.regex_filter); std::smatch baseMatch; - printer.header(); - + std::vector<ankerl::nanobench::Result> benchmarkResults; for (const auto& p : benchmarks()) { if (!std::regex_match(p.first, baseMatch, reFilter)) { continue; } - uint64_t num_iters = static_cast<uint64_t>(p.second.num_iters_for_one_second * scaling); - if (0 == num_iters) { - num_iters = 1; - } - State state(p.first, num_evals, num_iters, printer); - if (!is_list_only) { - p.second.func(state); + if (args.is_list_only) { + std::cout << p.first << std::endl; + continue; } - printer.result(state); - } - - printer.footer(); -} - -bool benchmark::State::UpdateTimer(const benchmark::time_point current_time) -{ - if (m_start_time != time_point()) { - std::chrono::duration<double> diff = current_time - m_start_time; - m_elapsed_results.push_back(diff.count() / m_num_iters); - if (m_elapsed_results.size() == m_num_evals) { - return false; + Bench bench; + bench.name(p.first); + if (args.asymptote.empty()) { + p.second(bench); + } else { + for (auto n : args.asymptote) { + bench.complexityN(n); + p.second(bench); + } + std::cout << bench.complexityBigO() << std::endl; } + benchmarkResults.push_back(bench.results().back()); } - m_num_iters_left = m_num_iters - 1; - return true; + GenerateTemplateResults(benchmarkResults, args.output_csv, "# Benchmark, evals, iterations, total, min, max, median\n" + "{{#result}}{{name}}, {{epochs}}, {{average(iterations)}}, {{sumProduct(iterations, elapsed)}}, {{minimum(elapsed)}}, {{maximum(elapsed)}}, {{median(elapsed)}}\n" + "{{/result}}"); + GenerateTemplateResults(benchmarkResults, args.output_json, ankerl::nanobench::templates::json()); } diff --git a/src/bench/bench.h b/src/bench/bench.h index 629bca9a73..bafc7f8716 100644 --- a/src/bench/bench.h +++ b/src/bench/bench.h @@ -11,131 +11,53 @@ #include <string> #include <vector> +#include <bench/nanobench.h> #include <boost/preprocessor/cat.hpp> #include <boost/preprocessor/stringize.hpp> -// Simple micro-benchmarking framework; API mostly matches a subset of the Google Benchmark -// framework (see https://github.com/google/benchmark) -// Why not use the Google Benchmark framework? Because adding Yet Another Dependency -// (that uses cmake as its build system and has lots of features we don't need) isn't -// worth it. - /* * Usage: -static void CODE_TO_TIME(benchmark::State& state) +static void CODE_TO_TIME(benchmark::Bench& bench) { ... do any setup needed... - while (state.KeepRunning()) { + nanobench::Config().run([&] { ... do stuff you want to time... - } + }); ... do any cleanup needed... } -// default to running benchmark for 5000 iterations -BENCHMARK(CODE_TO_TIME, 5000); +BENCHMARK(CODE_TO_TIME); */ namespace benchmark { -// In case high_resolution_clock is steady, prefer that, otherwise use steady_clock. -struct best_clock { - using hi_res_clock = std::chrono::high_resolution_clock; - using steady_clock = std::chrono::steady_clock; - using type = std::conditional<hi_res_clock::is_steady, hi_res_clock, steady_clock>::type; -}; -using clock = best_clock::type; -using time_point = clock::time_point; -using duration = clock::duration; - -class Printer; - -class State -{ -public: - std::string m_name; - uint64_t m_num_iters_left; - const uint64_t m_num_iters; - const uint64_t m_num_evals; - std::vector<double> m_elapsed_results; - time_point m_start_time; - bool UpdateTimer(time_point finish_time); +using ankerl::nanobench::Bench; - State(std::string name, uint64_t num_evals, double num_iters, Printer& printer) : m_name(name), m_num_iters_left(0), m_num_iters(num_iters), m_num_evals(num_evals) - { - } +typedef std::function<void(Bench&)> BenchFunction; - inline bool KeepRunning() - { - if (m_num_iters_left--) { - return true; - } - - bool result = UpdateTimer(clock::now()); - // measure again so runtime of UpdateTimer is not included - m_start_time = clock::now(); - return result; - } +struct Args { + std::string regex_filter; + bool is_list_only; + std::vector<double> asymptote; + std::string output_csv; + std::string output_json; }; -typedef std::function<void(State&)> BenchFunction; - class BenchRunner { - struct Bench { - BenchFunction func; - uint64_t num_iters_for_one_second; - }; - typedef std::map<std::string, Bench> BenchmarkMap; + typedef std::map<std::string, BenchFunction> BenchmarkMap; static BenchmarkMap& benchmarks(); public: - BenchRunner(std::string name, BenchFunction func, uint64_t num_iters_for_one_second); - - static void RunAll(Printer& printer, uint64_t num_evals, double scaling, const std::string& filter, bool is_list_only); -}; + BenchRunner(std::string name, BenchFunction func); -// interface to output benchmark results. -class Printer -{ -public: - virtual ~Printer() {} - virtual void header() = 0; - virtual void result(const State& state) = 0; - virtual void footer() = 0; -}; - -// default printer to console, shows min, max, median. -class ConsolePrinter : public Printer -{ -public: - void header() override; - void result(const State& state) override; - void footer() override; -}; - -// creates box plot with plotly.js -class PlotlyPrinter : public Printer -{ -public: - PlotlyPrinter(std::string plotly_url, int64_t width, int64_t height); - void header() override; - void result(const State& state) override; - void footer() override; - -private: - std::string m_plotly_url; - int64_t m_width; - int64_t m_height; + static void RunAll(const Args& args); }; } - - -// BENCHMARK(foo, num_iters_for_one_second) expands to: benchmark::BenchRunner bench_11foo("foo", num_iterations); -// Choose a num_iters_for_one_second that takes roughly 1 second. The goal is that all benchmarks should take approximately -// the same time, and scaling factor can be used that the total time is appropriate for your system. -#define BENCHMARK(n, num_iters_for_one_second) \ - benchmark::BenchRunner BOOST_PP_CAT(bench_, BOOST_PP_CAT(__LINE__, n))(BOOST_PP_STRINGIZE(n), n, (num_iters_for_one_second)); +// BENCHMARK(foo) expands to: benchmark::BenchRunner bench_11foo("foo"); +#define BENCHMARK(n) \ + benchmark::BenchRunner BOOST_PP_CAT(bench_, BOOST_PP_CAT(__LINE__, n))(BOOST_PP_STRINGIZE(n), n); #endif // BITCOIN_BENCH_BENCH_H diff --git a/src/bench/bench_bitcoin.cpp b/src/bench/bench_bitcoin.cpp index 1b75854210..135659f87f 100644 --- a/src/bench/bench_bitcoin.cpp +++ b/src/bench/bench_bitcoin.cpp @@ -4,37 +4,43 @@ #include <bench/bench.h> +#include <crypto/sha256.h> #include <util/strencodings.h> #include <util/system.h> #include <memory> -static const int64_t DEFAULT_BENCH_EVALUATIONS = 5; static const char* DEFAULT_BENCH_FILTER = ".*"; -static const char* DEFAULT_BENCH_SCALING = "1.0"; -static const char* DEFAULT_BENCH_PRINTER = "console"; -static const char* DEFAULT_PLOT_PLOTLYURL = "https://cdn.plot.ly/plotly-latest.min.js"; -static const int64_t DEFAULT_PLOT_WIDTH = 1024; -static const int64_t DEFAULT_PLOT_HEIGHT = 768; static void SetupBenchArgs(ArgsManager& argsman) { SetupHelpOptions(argsman); - argsman.AddArg("-list", "List benchmarks without executing them. Can be combined with -scaling and -filter", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - argsman.AddArg("-evals=<n>", strprintf("Number of measurement evaluations to perform. (default: %u)", DEFAULT_BENCH_EVALUATIONS), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-list", "List benchmarks without executing them", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); argsman.AddArg("-filter=<regex>", strprintf("Regular expression filter to select benchmark by name (default: %s)", DEFAULT_BENCH_FILTER), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - argsman.AddArg("-scaling=<n>", strprintf("Scaling factor for benchmark's runtime (default: %u)", DEFAULT_BENCH_SCALING), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - argsman.AddArg("-printer=(console|plot)", strprintf("Choose printer format. console: print data to console. plot: Print results as HTML graph (default: %s)", DEFAULT_BENCH_PRINTER), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - argsman.AddArg("-plot-plotlyurl=<uri>", strprintf("URL to use for plotly.js (default: %s)", DEFAULT_PLOT_PLOTLYURL), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - argsman.AddArg("-plot-width=<x>", strprintf("Plot width in pixel (default: %u)", DEFAULT_PLOT_WIDTH), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - argsman.AddArg("-plot-height=<x>", strprintf("Plot height in pixel (default: %u)", DEFAULT_PLOT_HEIGHT), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-asymptote=n1,n2,n3,...", strprintf("Test asymptotic growth of the runtime of an algorithm, if supported by the benchmark"), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-output_csv=<output.csv>", "Generate CSV file with the most important benchmark results.", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-output_json=<output.json>", "Generate JSON file with all benchmark results.", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); +} + +// parses a comma separated list like "10,20,30,50" +static std::vector<double> parseAsymptote(const std::string& str) { + std::stringstream ss(str); + std::vector<double> numbers; + double d; + char c; + while (ss >> d) { + numbers.push_back(d); + ss >> c; + } + return numbers; } int main(int argc, char** argv) { ArgsManager argsman; SetupBenchArgs(argsman); + SHA256AutoDetect(); std::string error; if (!argsman.ParseParameters(argc, argv, error)) { tfm::format(std::cerr, "Error parsing command line arguments: %s\n", error); @@ -47,34 +53,14 @@ int main(int argc, char** argv) return EXIT_SUCCESS; } - int64_t evaluations = argsman.GetArg("-evals", DEFAULT_BENCH_EVALUATIONS); - std::string regex_filter = argsman.GetArg("-filter", DEFAULT_BENCH_FILTER); - std::string scaling_str = argsman.GetArg("-scaling", DEFAULT_BENCH_SCALING); - bool is_list_only = argsman.GetBoolArg("-list", false); - - if (evaluations == 0) { - return EXIT_SUCCESS; - } else if (evaluations < 0) { - tfm::format(std::cerr, "Error parsing evaluations argument: %d\n", evaluations); - return EXIT_FAILURE; - } - - double scaling_factor; - if (!ParseDouble(scaling_str, &scaling_factor)) { - tfm::format(std::cerr, "Error parsing scaling factor as double: %s\n", scaling_str); - return EXIT_FAILURE; - } - - std::unique_ptr<benchmark::Printer> printer = MakeUnique<benchmark::ConsolePrinter>(); - std::string printer_arg = argsman.GetArg("-printer", DEFAULT_BENCH_PRINTER); - if ("plot" == printer_arg) { - printer.reset(new benchmark::PlotlyPrinter( - argsman.GetArg("-plot-plotlyurl", DEFAULT_PLOT_PLOTLYURL), - argsman.GetArg("-plot-width", DEFAULT_PLOT_WIDTH), - argsman.GetArg("-plot-height", DEFAULT_PLOT_HEIGHT))); - } + benchmark::Args args; + args.regex_filter = argsman.GetArg("-filter", DEFAULT_BENCH_FILTER); + args.is_list_only = argsman.GetBoolArg("-list", false); + args.asymptote = parseAsymptote(argsman.GetArg("-asymptote", "")); + args.output_csv = argsman.GetArg("-output_csv", ""); + args.output_json = argsman.GetArg("-output_json", ""); - benchmark::BenchRunner::RunAll(*printer, evaluations, scaling_factor, regex_filter, is_list_only); + benchmark::BenchRunner::RunAll(args); return EXIT_SUCCESS; } diff --git a/src/bench/block_assemble.cpp b/src/bench/block_assemble.cpp index 268f67cada..3f15f3f856 100644 --- a/src/bench/block_assemble.cpp +++ b/src/bench/block_assemble.cpp @@ -14,7 +14,7 @@ #include <vector> -static void AssembleBlock(benchmark::State& state) +static void AssembleBlock(benchmark::Bench& bench) { TestingSetup test_setup{ CBaseChainParams::REGTEST, @@ -54,9 +54,9 @@ static void AssembleBlock(benchmark::State& state) } } - while (state.KeepRunning()) { + bench.run([&] { PrepareBlock(test_setup.m_node, SCRIPT_PUB); - } + }); } -BENCHMARK(AssembleBlock, 700); +BENCHMARK(AssembleBlock); diff --git a/src/bench/ccoins_caching.cpp b/src/bench/ccoins_caching.cpp index 86f9a0bf67..116de98b14 100644 --- a/src/bench/ccoins_caching.cpp +++ b/src/bench/ccoins_caching.cpp @@ -16,7 +16,7 @@ // characteristics than e.g. reindex timings. But that's not a requirement of // every benchmark." // (https://github.com/bitcoin/bitcoin/issues/7883#issuecomment-224807484) -static void CCoinsCaching(benchmark::State& state) +static void CCoinsCaching(benchmark::Bench& bench) { const ECCVerifyHandle verify_handle; ECC_Start(); @@ -44,11 +44,11 @@ static void CCoinsCaching(benchmark::State& state) // Benchmark. const CTransaction tx_1(t1); - while (state.KeepRunning()) { + bench.run([&] { bool success = AreInputsStandard(tx_1, coins); assert(success); - } + }); ECC_Stop(); } -BENCHMARK(CCoinsCaching, 170 * 1000); +BENCHMARK(CCoinsCaching); diff --git a/src/bench/chacha20.cpp b/src/bench/chacha20.cpp index f1b0a9a989..913e0f8d57 100644 --- a/src/bench/chacha20.cpp +++ b/src/bench/chacha20.cpp @@ -11,7 +11,7 @@ static const uint64_t BUFFER_SIZE_TINY = 64; static const uint64_t BUFFER_SIZE_SMALL = 256; static const uint64_t BUFFER_SIZE_LARGE = 1024*1024; -static void CHACHA20(benchmark::State& state, size_t buffersize) +static void CHACHA20(benchmark::Bench& bench, size_t buffersize) { std::vector<uint8_t> key(32,0); ChaCha20 ctx(key.data(), key.size()); @@ -19,26 +19,26 @@ static void CHACHA20(benchmark::State& state, size_t buffersize) ctx.Seek(0); std::vector<uint8_t> in(buffersize,0); std::vector<uint8_t> out(buffersize,0); - while (state.KeepRunning()) { + bench.batch(in.size()).unit("byte").run([&] { ctx.Crypt(in.data(), out.data(), in.size()); - } + }); } -static void CHACHA20_64BYTES(benchmark::State& state) +static void CHACHA20_64BYTES(benchmark::Bench& bench) { - CHACHA20(state, BUFFER_SIZE_TINY); + CHACHA20(bench, BUFFER_SIZE_TINY); } -static void CHACHA20_256BYTES(benchmark::State& state) +static void CHACHA20_256BYTES(benchmark::Bench& bench) { - CHACHA20(state, BUFFER_SIZE_SMALL); + CHACHA20(bench, BUFFER_SIZE_SMALL); } -static void CHACHA20_1MB(benchmark::State& state) +static void CHACHA20_1MB(benchmark::Bench& bench) { - CHACHA20(state, BUFFER_SIZE_LARGE); + CHACHA20(bench, BUFFER_SIZE_LARGE); } -BENCHMARK(CHACHA20_64BYTES, 500000); -BENCHMARK(CHACHA20_256BYTES, 250000); -BENCHMARK(CHACHA20_1MB, 340); +BENCHMARK(CHACHA20_64BYTES); +BENCHMARK(CHACHA20_256BYTES); +BENCHMARK(CHACHA20_1MB); diff --git a/src/bench/chacha_poly_aead.cpp b/src/bench/chacha_poly_aead.cpp index df10f27d03..3b1d3e697a 100644 --- a/src/bench/chacha_poly_aead.cpp +++ b/src/bench/chacha_poly_aead.cpp @@ -21,7 +21,7 @@ static const unsigned char k2[32] = {0}; static ChaCha20Poly1305AEAD aead(k1, 32, k2, 32); -static void CHACHA20_POLY1305_AEAD(benchmark::State& state, size_t buffersize, bool include_decryption) +static void CHACHA20_POLY1305_AEAD(benchmark::Bench& bench, size_t buffersize, bool include_decryption) { std::vector<unsigned char> in(buffersize + CHACHA20_POLY1305_AEAD_AAD_LEN + POLY1305_TAGLEN, 0); std::vector<unsigned char> out(buffersize + CHACHA20_POLY1305_AEAD_AAD_LEN + POLY1305_TAGLEN, 0); @@ -29,7 +29,7 @@ static void CHACHA20_POLY1305_AEAD(benchmark::State& state, size_t buffersize, b uint64_t seqnr_aad = 0; int aad_pos = 0; uint32_t len = 0; - while (state.KeepRunning()) { + bench.batch(buffersize).unit("byte").run([&] { // encrypt or decrypt the buffer with a static key assert(aead.Crypt(seqnr_payload, seqnr_aad, aad_pos, out.data(), out.size(), in.data(), buffersize, true)); @@ -53,70 +53,71 @@ static void CHACHA20_POLY1305_AEAD(benchmark::State& state, size_t buffersize, b seqnr_aad = 0; aad_pos = 0; } - } + }); } -static void CHACHA20_POLY1305_AEAD_64BYTES_ONLY_ENCRYPT(benchmark::State& state) +static void CHACHA20_POLY1305_AEAD_64BYTES_ONLY_ENCRYPT(benchmark::Bench& bench) { - CHACHA20_POLY1305_AEAD(state, BUFFER_SIZE_TINY, false); + CHACHA20_POLY1305_AEAD(bench, BUFFER_SIZE_TINY, false); } -static void CHACHA20_POLY1305_AEAD_256BYTES_ONLY_ENCRYPT(benchmark::State& state) +static void CHACHA20_POLY1305_AEAD_256BYTES_ONLY_ENCRYPT(benchmark::Bench& bench) { - CHACHA20_POLY1305_AEAD(state, BUFFER_SIZE_SMALL, false); + CHACHA20_POLY1305_AEAD(bench, BUFFER_SIZE_SMALL, false); } -static void CHACHA20_POLY1305_AEAD_1MB_ONLY_ENCRYPT(benchmark::State& state) +static void CHACHA20_POLY1305_AEAD_1MB_ONLY_ENCRYPT(benchmark::Bench& bench) { - CHACHA20_POLY1305_AEAD(state, BUFFER_SIZE_LARGE, false); + CHACHA20_POLY1305_AEAD(bench, BUFFER_SIZE_LARGE, false); } -static void CHACHA20_POLY1305_AEAD_64BYTES_ENCRYPT_DECRYPT(benchmark::State& state) +static void CHACHA20_POLY1305_AEAD_64BYTES_ENCRYPT_DECRYPT(benchmark::Bench& bench) { - CHACHA20_POLY1305_AEAD(state, BUFFER_SIZE_TINY, true); + CHACHA20_POLY1305_AEAD(bench, BUFFER_SIZE_TINY, true); } -static void CHACHA20_POLY1305_AEAD_256BYTES_ENCRYPT_DECRYPT(benchmark::State& state) +static void CHACHA20_POLY1305_AEAD_256BYTES_ENCRYPT_DECRYPT(benchmark::Bench& bench) { - CHACHA20_POLY1305_AEAD(state, BUFFER_SIZE_SMALL, true); + CHACHA20_POLY1305_AEAD(bench, BUFFER_SIZE_SMALL, true); } -static void CHACHA20_POLY1305_AEAD_1MB_ENCRYPT_DECRYPT(benchmark::State& state) +static void CHACHA20_POLY1305_AEAD_1MB_ENCRYPT_DECRYPT(benchmark::Bench& bench) { - CHACHA20_POLY1305_AEAD(state, BUFFER_SIZE_LARGE, true); + CHACHA20_POLY1305_AEAD(bench, BUFFER_SIZE_LARGE, true); } // Add Hash() (dbl-sha256) bench for comparison -static void HASH(benchmark::State& state, size_t buffersize) +static void HASH(benchmark::Bench& bench, size_t buffersize) { uint8_t hash[CHash256::OUTPUT_SIZE]; std::vector<uint8_t> in(buffersize,0); - while (state.KeepRunning()) - CHash256().Write(in.data(), in.size()).Finalize(hash); + bench.batch(in.size()).unit("byte").run([&] { + CHash256().Write(in).Finalize(hash); + }); } -static void HASH_64BYTES(benchmark::State& state) +static void HASH_64BYTES(benchmark::Bench& bench) { - HASH(state, BUFFER_SIZE_TINY); + HASH(bench, BUFFER_SIZE_TINY); } -static void HASH_256BYTES(benchmark::State& state) +static void HASH_256BYTES(benchmark::Bench& bench) { - HASH(state, BUFFER_SIZE_SMALL); + HASH(bench, BUFFER_SIZE_SMALL); } -static void HASH_1MB(benchmark::State& state) +static void HASH_1MB(benchmark::Bench& bench) { - HASH(state, BUFFER_SIZE_LARGE); + HASH(bench, BUFFER_SIZE_LARGE); } -BENCHMARK(CHACHA20_POLY1305_AEAD_64BYTES_ONLY_ENCRYPT, 500000); -BENCHMARK(CHACHA20_POLY1305_AEAD_256BYTES_ONLY_ENCRYPT, 250000); -BENCHMARK(CHACHA20_POLY1305_AEAD_1MB_ONLY_ENCRYPT, 340); -BENCHMARK(CHACHA20_POLY1305_AEAD_64BYTES_ENCRYPT_DECRYPT, 500000); -BENCHMARK(CHACHA20_POLY1305_AEAD_256BYTES_ENCRYPT_DECRYPT, 250000); -BENCHMARK(CHACHA20_POLY1305_AEAD_1MB_ENCRYPT_DECRYPT, 340); -BENCHMARK(HASH_64BYTES, 500000); -BENCHMARK(HASH_256BYTES, 250000); -BENCHMARK(HASH_1MB, 340); +BENCHMARK(CHACHA20_POLY1305_AEAD_64BYTES_ONLY_ENCRYPT); +BENCHMARK(CHACHA20_POLY1305_AEAD_256BYTES_ONLY_ENCRYPT); +BENCHMARK(CHACHA20_POLY1305_AEAD_1MB_ONLY_ENCRYPT); +BENCHMARK(CHACHA20_POLY1305_AEAD_64BYTES_ENCRYPT_DECRYPT); +BENCHMARK(CHACHA20_POLY1305_AEAD_256BYTES_ENCRYPT_DECRYPT); +BENCHMARK(CHACHA20_POLY1305_AEAD_1MB_ENCRYPT_DECRYPT); +BENCHMARK(HASH_64BYTES); +BENCHMARK(HASH_256BYTES); +BENCHMARK(HASH_1MB); diff --git a/src/bench/checkblock.cpp b/src/bench/checkblock.cpp index 2b2c78905e..dc0aa4031c 100644 --- a/src/bench/checkblock.cpp +++ b/src/bench/checkblock.cpp @@ -14,21 +14,21 @@ // a block off the wire, but before we can relay the block on to peers using // compact block relay. -static void DeserializeBlockTest(benchmark::State& state) +static void DeserializeBlockTest(benchmark::Bench& bench) { CDataStream stream(benchmark::data::block413567, SER_NETWORK, PROTOCOL_VERSION); char a = '\0'; stream.write(&a, 1); // Prevent compaction - while (state.KeepRunning()) { + bench.unit("block").run([&] { CBlock block; stream >> block; bool rewound = stream.Rewind(benchmark::data::block413567.size()); assert(rewound); - } + }); } -static void DeserializeAndCheckBlockTest(benchmark::State& state) +static void DeserializeAndCheckBlockTest(benchmark::Bench& bench) { CDataStream stream(benchmark::data::block413567, SER_NETWORK, PROTOCOL_VERSION); char a = '\0'; @@ -36,7 +36,7 @@ static void DeserializeAndCheckBlockTest(benchmark::State& state) const auto chainParams = CreateChainParams(CBaseChainParams::MAIN); - while (state.KeepRunning()) { + bench.unit("block").run([&] { CBlock block; // Note that CBlock caches its checked state, so we need to recreate it here stream >> block; bool rewound = stream.Rewind(benchmark::data::block413567.size()); @@ -45,8 +45,8 @@ static void DeserializeAndCheckBlockTest(benchmark::State& state) BlockValidationState validationState; bool checked = CheckBlock(block, validationState, chainParams->GetConsensus()); assert(checked); - } + }); } -BENCHMARK(DeserializeBlockTest, 130); -BENCHMARK(DeserializeAndCheckBlockTest, 160); +BENCHMARK(DeserializeBlockTest); +BENCHMARK(DeserializeAndCheckBlockTest); diff --git a/src/bench/checkqueue.cpp b/src/bench/checkqueue.cpp index e052681181..19d7bc0dbc 100644 --- a/src/bench/checkqueue.cpp +++ b/src/bench/checkqueue.cpp @@ -24,7 +24,7 @@ static const unsigned int QUEUE_BATCH_SIZE = 128; // This Benchmark tests the CheckQueue with a slightly realistic workload, // where checks all contain a prevector that is indirect 50% of the time // and there is a little bit of work done between calls to Add. -static void CCheckQueueSpeedPrevectorJob(benchmark::State& state) +static void CCheckQueueSpeedPrevectorJob(benchmark::Bench& bench) { const ECCVerifyHandle verify_handle; ECC_Start(); @@ -47,23 +47,28 @@ static void CCheckQueueSpeedPrevectorJob(benchmark::State& state) for (auto x = 0; x < std::max(MIN_CORES, GetNumCores()); ++x) { tg.create_thread([&]{queue.Thread();}); } - while (state.KeepRunning()) { + + // create all the data once, then submit copies in the benchmark. + FastRandomContext insecure_rand(true); + std::vector<std::vector<PrevectorJob>> vBatches(BATCHES); + for (auto& vChecks : vBatches) { + vChecks.reserve(BATCH_SIZE); + for (size_t x = 0; x < BATCH_SIZE; ++x) + vChecks.emplace_back(insecure_rand); + } + + bench.minEpochIterations(10).batch(BATCH_SIZE * BATCHES).unit("job").run([&] { // Make insecure_rand here so that each iteration is identical. - FastRandomContext insecure_rand(true); CCheckQueueControl<PrevectorJob> control(&queue); - std::vector<std::vector<PrevectorJob>> vBatches(BATCHES); - for (auto& vChecks : vBatches) { - vChecks.reserve(BATCH_SIZE); - for (size_t x = 0; x < BATCH_SIZE; ++x) - vChecks.emplace_back(insecure_rand); + for (auto vChecks : vBatches) { control.Add(vChecks); } // control waits for completion by RAII, but // it is done explicitly here for clarity control.Wait(); - } + }); tg.interrupt_all(); tg.join_all(); ECC_Stop(); } -BENCHMARK(CCheckQueueSpeedPrevectorJob, 1400); +BENCHMARK(CCheckQueueSpeedPrevectorJob); diff --git a/src/bench/coin_selection.cpp b/src/bench/coin_selection.cpp index d6d5e67c5b..3a71a6ca03 100644 --- a/src/bench/coin_selection.cpp +++ b/src/bench/coin_selection.cpp @@ -27,11 +27,11 @@ static void addCoin(const CAmount& nValue, const CWallet& wallet, std::vector<st // same one over and over isn't too useful. Generating random isn't useful // either for measurements." // (https://github.com/bitcoin/bitcoin/issues/7883#issuecomment-224807484) -static void CoinSelection(benchmark::State& state) +static void CoinSelection(benchmark::Bench& bench) { NodeContext node; auto chain = interfaces::MakeChain(node); - CWallet wallet(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet wallet(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); wallet.SetupLegacyScriptPubKeyMan(); std::vector<std::unique_ptr<CWalletTx>> wtxs; LOCK(wallet.cs_wallet); @@ -51,7 +51,7 @@ static void CoinSelection(benchmark::State& state) const CoinEligibilityFilter filter_standard(1, 6, 0); const CoinSelectionParams coin_selection_params(true, 34, 148, CFeeRate(0), 0); - while (state.KeepRunning()) { + bench.run([&] { std::set<CInputCoin> setCoinsRet; CAmount nValueRet; bool bnb_used; @@ -59,13 +59,13 @@ static void CoinSelection(benchmark::State& state) assert(success); assert(nValueRet == 1003 * COIN); assert(setCoinsRet.size() == 2); - } + }); } typedef std::set<CInputCoin> CoinSet; static NodeContext testNode; static auto testChain = interfaces::MakeChain(testNode); -static CWallet testWallet(testChain.get(), WalletLocation(), WalletDatabase::CreateDummy()); +static CWallet testWallet(testChain.get(), WalletLocation(), CreateDummyWalletDatabase()); std::vector<std::unique_ptr<CWalletTx>> wtxn; // Copied from src/wallet/test/coinselector_tests.cpp @@ -91,7 +91,7 @@ static CAmount make_hard_case(int utxos, std::vector<OutputGroup>& utxo_pool) return target; } -static void BnBExhaustion(benchmark::State& state) +static void BnBExhaustion(benchmark::Bench& bench) { // Setup testWallet.SetupLegacyScriptPubKeyMan(); @@ -100,7 +100,7 @@ static void BnBExhaustion(benchmark::State& state) CAmount value_ret = 0; CAmount not_input_fees = 0; - while (state.KeepRunning()) { + bench.run([&] { // Benchmark CAmount target = make_hard_case(17, utxo_pool); SelectCoinsBnB(utxo_pool, target, 0, selection, value_ret, not_input_fees); // Should exhaust @@ -108,8 +108,8 @@ static void BnBExhaustion(benchmark::State& state) // Cleanup utxo_pool.clear(); selection.clear(); - } + }); } -BENCHMARK(CoinSelection, 650); -BENCHMARK(BnBExhaustion, 650); +BENCHMARK(CoinSelection); +BENCHMARK(BnBExhaustion); diff --git a/src/bench/crypto_hash.cpp b/src/bench/crypto_hash.cpp index ddcef5121e..36be86bcc8 100644 --- a/src/bench/crypto_hash.cpp +++ b/src/bench/crypto_hash.cpp @@ -16,88 +16,92 @@ /* Number of bytes to hash per iteration */ static const uint64_t BUFFER_SIZE = 1000*1000; -static void RIPEMD160(benchmark::State& state) +static void RIPEMD160(benchmark::Bench& bench) { uint8_t hash[CRIPEMD160::OUTPUT_SIZE]; std::vector<uint8_t> in(BUFFER_SIZE,0); - while (state.KeepRunning()) + bench.batch(in.size()).unit("byte").run([&] { CRIPEMD160().Write(in.data(), in.size()).Finalize(hash); + }); } -static void SHA1(benchmark::State& state) +static void SHA1(benchmark::Bench& bench) { uint8_t hash[CSHA1::OUTPUT_SIZE]; std::vector<uint8_t> in(BUFFER_SIZE,0); - while (state.KeepRunning()) + bench.batch(in.size()).unit("byte").run([&] { CSHA1().Write(in.data(), in.size()).Finalize(hash); + }); } -static void SHA256(benchmark::State& state) +static void SHA256(benchmark::Bench& bench) { uint8_t hash[CSHA256::OUTPUT_SIZE]; std::vector<uint8_t> in(BUFFER_SIZE,0); - while (state.KeepRunning()) + bench.batch(in.size()).unit("byte").run([&] { CSHA256().Write(in.data(), in.size()).Finalize(hash); + }); } -static void SHA256_32b(benchmark::State& state) +static void SHA256_32b(benchmark::Bench& bench) { std::vector<uint8_t> in(32,0); - while (state.KeepRunning()) { + bench.batch(in.size()).unit("byte").run([&] { CSHA256() .Write(in.data(), in.size()) .Finalize(in.data()); - } + }); } -static void SHA256D64_1024(benchmark::State& state) +static void SHA256D64_1024(benchmark::Bench& bench) { std::vector<uint8_t> in(64 * 1024, 0); - while (state.KeepRunning()) { + bench.batch(in.size()).unit("byte").run([&] { SHA256D64(in.data(), in.data(), 1024); - } + }); } -static void SHA512(benchmark::State& state) +static void SHA512(benchmark::Bench& bench) { uint8_t hash[CSHA512::OUTPUT_SIZE]; std::vector<uint8_t> in(BUFFER_SIZE,0); - while (state.KeepRunning()) + bench.batch(in.size()).unit("byte").run([&] { CSHA512().Write(in.data(), in.size()).Finalize(hash); + }); } -static void SipHash_32b(benchmark::State& state) +static void SipHash_32b(benchmark::Bench& bench) { uint256 x; uint64_t k1 = 0; - while (state.KeepRunning()) { + bench.run([&] { *((uint64_t*)x.begin()) = SipHashUint256(0, ++k1, x); - } + }); } -static void FastRandom_32bit(benchmark::State& state) +static void FastRandom_32bit(benchmark::Bench& bench) { FastRandomContext rng(true); - while (state.KeepRunning()) { + bench.run([&] { rng.rand32(); - } + }); } -static void FastRandom_1bit(benchmark::State& state) +static void FastRandom_1bit(benchmark::Bench& bench) { FastRandomContext rng(true); - while (state.KeepRunning()) { + bench.run([&] { rng.randbool(); - } + }); } -BENCHMARK(RIPEMD160, 440); -BENCHMARK(SHA1, 570); -BENCHMARK(SHA256, 340); -BENCHMARK(SHA512, 330); +BENCHMARK(RIPEMD160); +BENCHMARK(SHA1); +BENCHMARK(SHA256); +BENCHMARK(SHA512); -BENCHMARK(SHA256_32b, 4700 * 1000); -BENCHMARK(SipHash_32b, 40 * 1000 * 1000); -BENCHMARK(SHA256D64_1024, 7400); -BENCHMARK(FastRandom_32bit, 110 * 1000 * 1000); -BENCHMARK(FastRandom_1bit, 440 * 1000 * 1000); +BENCHMARK(SHA256_32b); +BENCHMARK(SipHash_32b); +BENCHMARK(SHA256D64_1024); +BENCHMARK(FastRandom_32bit); +BENCHMARK(FastRandom_1bit); diff --git a/src/bench/duplicate_inputs.cpp b/src/bench/duplicate_inputs.cpp index e87f15042b..5745e4276c 100644 --- a/src/bench/duplicate_inputs.cpp +++ b/src/bench/duplicate_inputs.cpp @@ -12,7 +12,7 @@ #include <validation.h> -static void DuplicateInputs(benchmark::State& state) +static void DuplicateInputs(benchmark::Bench& bench) { TestingSetup test_setup{ CBaseChainParams::REGTEST, @@ -61,11 +61,11 @@ static void DuplicateInputs(benchmark::State& state) block.hashMerkleRoot = BlockMerkleRoot(block); - while (state.KeepRunning()) { + bench.run([&] { BlockValidationState cvstate{}; assert(!CheckBlock(block, cvstate, chainparams.GetConsensus(), false, false)); assert(cvstate.GetRejectReason() == "bad-txns-inputs-duplicate"); - } + }); } -BENCHMARK(DuplicateInputs, 10); +BENCHMARK(DuplicateInputs); diff --git a/src/bench/examples.cpp b/src/bench/examples.cpp index f88150200a..dcd615b9da 100644 --- a/src/bench/examples.cpp +++ b/src/bench/examples.cpp @@ -3,31 +3,19 @@ // file COPYING or http://www.opensource.org/licenses/mit-license.php. #include <bench/bench.h> -#include <util/time.h> - -// Sanity test: this should loop ten times, and -// min/max/average should be close to 100ms. -static void Sleep100ms(benchmark::State& state) -{ - while (state.KeepRunning()) { - UninterruptibleSleep(std::chrono::milliseconds{100}); - } -} - -BENCHMARK(Sleep100ms, 10); // Extremely fast-running benchmark: #include <math.h> volatile double sum = 0.0; // volatile, global so not optimized away -static void Trig(benchmark::State& state) +static void Trig(benchmark::Bench& bench) { double d = 0.01; - while (state.KeepRunning()) { + bench.run([&] { sum += sin(d); d += 0.000001; - } + }); } -BENCHMARK(Trig, 12 * 1000 * 1000); +BENCHMARK(Trig); diff --git a/src/bench/gcs_filter.cpp b/src/bench/gcs_filter.cpp index 535ad35571..ef83242e41 100644 --- a/src/bench/gcs_filter.cpp +++ b/src/bench/gcs_filter.cpp @@ -5,7 +5,7 @@ #include <bench/bench.h> #include <blockfilter.h> -static void ConstructGCSFilter(benchmark::State& state) +static void ConstructGCSFilter(benchmark::Bench& bench) { GCSFilter::ElementSet elements; for (int i = 0; i < 10000; ++i) { @@ -16,14 +16,14 @@ static void ConstructGCSFilter(benchmark::State& state) } uint64_t siphash_k0 = 0; - while (state.KeepRunning()) { + bench.batch(elements.size()).unit("elem").run([&] { GCSFilter filter({siphash_k0, 0, 20, 1 << 20}, elements); siphash_k0++; - } + }); } -static void MatchGCSFilter(benchmark::State& state) +static void MatchGCSFilter(benchmark::Bench& bench) { GCSFilter::ElementSet elements; for (int i = 0; i < 10000; ++i) { @@ -34,10 +34,10 @@ static void MatchGCSFilter(benchmark::State& state) } GCSFilter filter({0, 0, 20, 1 << 20}, elements); - while (state.KeepRunning()) { + bench.unit("elem").run([&] { filter.Match(GCSFilter::Element()); - } + }); } -BENCHMARK(ConstructGCSFilter, 1000); -BENCHMARK(MatchGCSFilter, 50 * 1000); +BENCHMARK(ConstructGCSFilter); +BENCHMARK(MatchGCSFilter); diff --git a/src/bench/hashpadding.cpp b/src/bench/hashpadding.cpp new file mode 100644 index 0000000000..309cae3723 --- /dev/null +++ b/src/bench/hashpadding.cpp @@ -0,0 +1,47 @@ +// Copyright (c) 2015-2018 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include <bench/bench.h> +#include <hash.h> +#include <random.h> +#include <uint256.h> + + +static void PrePadded(benchmark::Bench& bench) +{ + + CSHA256 hasher; + + // Setup the salted hasher + uint256 nonce = GetRandHash(); + hasher.Write(nonce.begin(), 32); + hasher.Write(nonce.begin(), 32); + uint256 data = GetRandHash(); + bench.run([&] { + unsigned char out[32]; + CSHA256 h = hasher; + h.Write(data.begin(), 32); + h.Finalize(out); + }); +} + +BENCHMARK(PrePadded); + +static void RegularPadded(benchmark::Bench& bench) +{ + CSHA256 hasher; + + // Setup the salted hasher + uint256 nonce = GetRandHash(); + uint256 data = GetRandHash(); + bench.run([&] { + unsigned char out[32]; + CSHA256 h = hasher; + h.Write(nonce.begin(), 32); + h.Write(data.begin(), 32); + h.Finalize(out); + }); +} + +BENCHMARK(RegularPadded); diff --git a/src/bench/lockedpool.cpp b/src/bench/lockedpool.cpp index 5d943810df..32b060a15a 100644 --- a/src/bench/lockedpool.cpp +++ b/src/bench/lockedpool.cpp @@ -9,10 +9,9 @@ #include <vector> #define ASIZE 2048 -#define BITER 5000 #define MSIZE 2048 -static void BenchLockedPool(benchmark::State& state) +static void BenchLockedPool(benchmark::Bench& bench) { void *synth_base = reinterpret_cast<void*>(0x08000000); const size_t synth_size = 1024*1024; @@ -22,24 +21,22 @@ static void BenchLockedPool(benchmark::State& state) for (int x=0; x<ASIZE; ++x) addr.push_back(nullptr); uint32_t s = 0x12345678; - while (state.KeepRunning()) { - for (int x=0; x<BITER; ++x) { - int idx = s & (addr.size()-1); - if (s & 0x80000000) { - b.free(addr[idx]); - addr[idx] = nullptr; - } else if(!addr[idx]) { - addr[idx] = b.alloc((s >> 16) & (MSIZE-1)); - } - bool lsb = s & 1; - s >>= 1; - if (lsb) - s ^= 0xf00f00f0; // LFSR period 0xf7ffffe0 + bench.run([&] { + int idx = s & (addr.size() - 1); + if (s & 0x80000000) { + b.free(addr[idx]); + addr[idx] = nullptr; + } else if (!addr[idx]) { + addr[idx] = b.alloc((s >> 16) & (MSIZE - 1)); } - } + bool lsb = s & 1; + s >>= 1; + if (lsb) + s ^= 0xf00f00f0; // LFSR period 0xf7ffffe0 + }); for (void *ptr: addr) b.free(ptr); addr.clear(); } -BENCHMARK(BenchLockedPool, 1300); +BENCHMARK(BenchLockedPool); diff --git a/src/bench/mempool_eviction.cpp b/src/bench/mempool_eviction.cpp index 69483f2914..1b9e428c9d 100644 --- a/src/bench/mempool_eviction.cpp +++ b/src/bench/mempool_eviction.cpp @@ -23,7 +23,7 @@ static void AddTx(const CTransactionRef& tx, const CAmount& nFee, CTxMemPool& po // Right now this is only testing eviction performance in an extremely small // mempool. Code needs to be written to generate a much wider variety of // unique transactions for a more meaningful performance measurement. -static void MempoolEviction(benchmark::State& state) +static void MempoolEviction(benchmark::Bench& bench) { TestingSetup test_setup{ CBaseChainParams::REGTEST, @@ -125,7 +125,7 @@ static void MempoolEviction(benchmark::State& state) const CTransactionRef tx6_r{MakeTransactionRef(tx6)}; const CTransactionRef tx7_r{MakeTransactionRef(tx7)}; - while (state.KeepRunning()) { + bench.run([&]() NO_THREAD_SAFETY_ANALYSIS { AddTx(tx1_r, 10000LL, pool); AddTx(tx2_r, 5000LL, pool); AddTx(tx3_r, 20000LL, pool); @@ -135,7 +135,7 @@ static void MempoolEviction(benchmark::State& state) AddTx(tx7_r, 9000LL, pool); pool.TrimToSize(pool.DynamicMemoryUsage() * 3 / 4); pool.TrimToSize(GetVirtualTransactionSize(*tx1_r)); - } + }); } -BENCHMARK(MempoolEviction, 41000); +BENCHMARK(MempoolEviction); diff --git a/src/bench/mempool_stress.cpp b/src/bench/mempool_stress.cpp index 38d8632318..89233e390c 100644 --- a/src/bench/mempool_stress.cpp +++ b/src/bench/mempool_stress.cpp @@ -26,8 +26,13 @@ struct Available { Available(CTransactionRef& ref, size_t tx_count) : ref(ref), tx_count(tx_count){} }; -static void ComplexMemPool(benchmark::State& state) +static void ComplexMemPool(benchmark::Bench& bench) { + int childTxs = 800; + if (bench.complexityN() > 1) { + childTxs = static_cast<int>(bench.complexityN()); + } + FastRandomContext det_rand{true}; std::vector<Available> available_coins; std::vector<CTransactionRef> ordered_coins; @@ -46,7 +51,7 @@ static void ComplexMemPool(benchmark::State& state) ordered_coins.emplace_back(MakeTransactionRef(tx)); available_coins.emplace_back(ordered_coins.back(), tx_counter++); } - for (auto x = 0; x < 800 && !available_coins.empty(); ++x) { + for (auto x = 0; x < childTxs && !available_coins.empty(); ++x) { CMutableTransaction tx = CMutableTransaction(); size_t n_ancestors = det_rand.randrange(10)+1; for (size_t ancestor = 0; ancestor < n_ancestors && !available_coins.empty(); ++ancestor){ @@ -77,13 +82,13 @@ static void ComplexMemPool(benchmark::State& state) TestingSetup test_setup; CTxMemPool pool; LOCK2(cs_main, pool.cs); - while (state.KeepRunning()) { + bench.run([&]() NO_THREAD_SAFETY_ANALYSIS { for (auto& tx : ordered_coins) { AddTx(tx, pool); } pool.TrimToSize(pool.DynamicMemoryUsage() * 3 / 4); pool.TrimToSize(GetVirtualTransactionSize(*ordered_coins.front())); - } + }); } -BENCHMARK(ComplexMemPool, 1); +BENCHMARK(ComplexMemPool); diff --git a/src/bench/merkle_root.cpp b/src/bench/merkle_root.cpp index e84f92feae..ba6629b9f0 100644 --- a/src/bench/merkle_root.cpp +++ b/src/bench/merkle_root.cpp @@ -8,7 +8,7 @@ #include <random.h> #include <uint256.h> -static void MerkleRoot(benchmark::State& state) +static void MerkleRoot(benchmark::Bench& bench) { FastRandomContext rng(true); std::vector<uint256> leaves; @@ -16,11 +16,11 @@ static void MerkleRoot(benchmark::State& state) for (auto& item : leaves) { item = rng.rand256(); } - while (state.KeepRunning()) { + bench.batch(leaves.size()).unit("leaf").run([&] { bool mutation = false; uint256 hash = ComputeMerkleRoot(std::vector<uint256>(leaves), &mutation); leaves[mutation] = hash; - } + }); } -BENCHMARK(MerkleRoot, 800); +BENCHMARK(MerkleRoot); diff --git a/src/bench/nanobench.cpp b/src/bench/nanobench.cpp new file mode 100644 index 0000000000..fcdd86495a --- /dev/null +++ b/src/bench/nanobench.cpp @@ -0,0 +1,6 @@ +// Copyright (c) 2019-2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#define ANKERL_NANOBENCH_IMPLEMENT +#include <bench/nanobench.h> diff --git a/src/bench/nanobench.h b/src/bench/nanobench.h new file mode 100644 index 0000000000..c5379e7fd4 --- /dev/null +++ b/src/bench/nanobench.h @@ -0,0 +1,3225 @@ +// __ _ _______ __ _ _____ ______ _______ __ _ _______ _ _ +// | \ | |_____| | \ | | | |_____] |______ | \ | | |_____| +// | \_| | | | \_| |_____| |_____] |______ | \_| |_____ | | +// +// Microbenchmark framework for C++11/14/17/20 +// https://github.com/martinus/nanobench +// +// Licensed under the MIT License <http://opensource.org/licenses/MIT>. +// SPDX-License-Identifier: MIT +// Copyright (c) 2019-2020 Martin Ankerl <martin.ankerl@gmail.com> +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#ifndef ANKERL_NANOBENCH_H_INCLUDED +#define ANKERL_NANOBENCH_H_INCLUDED + +// see https://semver.org/ +#define ANKERL_NANOBENCH_VERSION_MAJOR 4 // incompatible API changes +#define ANKERL_NANOBENCH_VERSION_MINOR 0 // backwards-compatible changes +#define ANKERL_NANOBENCH_VERSION_PATCH 0 // backwards-compatible bug fixes + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// public facing api - as minimal as possible +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#include <chrono> // high_resolution_clock +#include <cstring> // memcpy +#include <iosfwd> // for std::ostream* custom output target in Config +#include <string> // all names +#include <vector> // holds all results + +#define ANKERL_NANOBENCH(x) ANKERL_NANOBENCH_PRIVATE_##x() + +#define ANKERL_NANOBENCH_PRIVATE_CXX() __cplusplus +#define ANKERL_NANOBENCH_PRIVATE_CXX98() 199711L +#define ANKERL_NANOBENCH_PRIVATE_CXX11() 201103L +#define ANKERL_NANOBENCH_PRIVATE_CXX14() 201402L +#define ANKERL_NANOBENCH_PRIVATE_CXX17() 201703L + +#if ANKERL_NANOBENCH(CXX) >= ANKERL_NANOBENCH(CXX17) +# define ANKERL_NANOBENCH_PRIVATE_NODISCARD() [[nodiscard]] +#else +# define ANKERL_NANOBENCH_PRIVATE_NODISCARD() +#endif + +#if defined(__clang__) +# define ANKERL_NANOBENCH_PRIVATE_IGNORE_PADDED_PUSH() \ + _Pragma("clang diagnostic push") _Pragma("clang diagnostic ignored \"-Wpadded\"") +# define ANKERL_NANOBENCH_PRIVATE_IGNORE_PADDED_POP() _Pragma("clang diagnostic pop") +#else +# define ANKERL_NANOBENCH_PRIVATE_IGNORE_PADDED_PUSH() +# define ANKERL_NANOBENCH_PRIVATE_IGNORE_PADDED_POP() +#endif + +#if defined(__GNUC__) +# define ANKERL_NANOBENCH_PRIVATE_IGNORE_EFFCPP_PUSH() _Pragma("GCC diagnostic push") _Pragma("GCC diagnostic ignored \"-Weffc++\"") +# define ANKERL_NANOBENCH_PRIVATE_IGNORE_EFFCPP_POP() _Pragma("GCC diagnostic pop") +#else +# define ANKERL_NANOBENCH_PRIVATE_IGNORE_EFFCPP_PUSH() +# define ANKERL_NANOBENCH_PRIVATE_IGNORE_EFFCPP_POP() +#endif + +#if defined(ANKERL_NANOBENCH_LOG_ENABLED) +# include <iostream> +# define ANKERL_NANOBENCH_LOG(x) std::cout << __FUNCTION__ << "@" << __LINE__ << ": " << x << std::endl +#else +# define ANKERL_NANOBENCH_LOG(x) +#endif + +#if defined(__linux__) && !defined(ANKERL_NANOBENCH_DISABLE_PERF_COUNTERS) +# define ANKERL_NANOBENCH_PRIVATE_PERF_COUNTERS() 1 +#else +# define ANKERL_NANOBENCH_PRIVATE_PERF_COUNTERS() 0 +#endif + +#if defined(__clang__) +# define ANKERL_NANOBENCH_NO_SANITIZE(...) __attribute__((no_sanitize(__VA_ARGS__))) +#else +# define ANKERL_NANOBENCH_NO_SANITIZE(...) +#endif + +#if defined(_MSC_VER) +# define ANKERL_NANOBENCH_PRIVATE_NOINLINE() __declspec(noinline) +#else +# define ANKERL_NANOBENCH_PRIVATE_NOINLINE() __attribute__((noinline)) +#endif + +// workaround missing "is_trivially_copyable" in g++ < 5.0 +// See https://stackoverflow.com/a/31798726/48181 +#if defined(__GNUC__) && __GNUC__ < 5 +# define ANKERL_NANOBENCH_IS_TRIVIALLY_COPYABLE(...) __has_trivial_copy(__VA_ARGS__) +#else +# define ANKERL_NANOBENCH_IS_TRIVIALLY_COPYABLE(...) std::is_trivially_copyable<__VA_ARGS__>::value +#endif + +// declarations /////////////////////////////////////////////////////////////////////////////////// + +namespace ankerl { +namespace nanobench { + +using Clock = std::conditional<std::chrono::high_resolution_clock::is_steady, std::chrono::high_resolution_clock, + std::chrono::steady_clock>::type; +class Bench; +struct Config; +class Result; +class Rng; +class BigO; + +/** + * @brief Renders output from a mustache-like template and benchmark results. + * + * The templating facility here is heavily inspired by [mustache - logic-less templates](https://mustache.github.io/). + * It adds a few more features that are necessary to get all of the captured data out of nanobench. Please read the + * excellent [mustache manual](https://mustache.github.io/mustache.5.html) to see what this is all about. + * + * nanobench output has two nested layers, *result* and *measurement*. Here is a hierarchy of the allowed tags: + * + * * `{{#result}}` Marks the begin of the result layer. Whatever comes after this will be instantiated as often as + * a benchmark result is available. Within it, you can use these tags: + * + * * `{{title}}` See Bench::title(). + * + * * `{{name}}` Benchmark name, usually directly provided with Bench::run(), but can also be set with Bench::name(). + * + * * `{{unit}}` Unit, e.g. `byte`. Defaults to `op`, see Bench::title(). + * + * * `{{batch}}` Batch size, see Bench::batch(). + * + * * `{{complexityN}}` Value used for asymptotic complexity calculation. See Bench::complexityN(). + * + * * `{{epochs}}` Number of epochs, see Bench::epochs(). + * + * * `{{clockResolution}}` Accuracy of the clock, i.e. what's the smallest time possible to measure with the clock. + * For modern systems, this can be around 20 ns. This value is automatically determined by nanobench at the first + * benchmark that is run, and used as a static variable throughout the application's runtime. + * + * * `{{clockResolutionMultiple}}` Configuration multiplier for `clockResolution`. See Bench::clockResolutionMultiple(). + * This is the target runtime for each measurement (epoch). That means the more accurate your clock is, the faster + * will be the benchmark. Basing the measurement's runtime on the clock resolution is the main reason why nanobench is so fast. + * + * * `{{maxEpochTime}}` Configuration for a maximum time each measurement (epoch) is allowed to take. Note that at least + * a single iteration will be performed, even when that takes longer than maxEpochTime. See Bench::maxEpochTime(). + * + * * `{{minEpochTime}}` Minimum epoch time, usually not set. See Bench::minEpochTime(). + * + * * `{{minEpochIterations}}` See Bench::minEpochIterations(). + * + * * `{{epochIterations}}` See Bench::epochIterations(). + * + * * `{{warmup}}` Number of iterations used before measuring starts. See Bench::warmup(). + * + * * `{{relative}}` True or false, depending on the setting you have used. See Bench::relative(). + * + * Apart from these tags, it is also possible to use some mathematical operations on the measurement data. The operations + * are of the form `{{command(name)}}`. Currently `name` can be one of `elapsed`, `iterations`. If performance counters + * are available (currently only on current Linux systems), you also have `pagefaults`, `cpucycles`, + * `contextswitches`, `instructions`, `branchinstructions`, and `branchmisses`. All the measuers (except `iterations`) are + * provided for a single iteration (so `elapsed` is the time a single iteration took). The following tags are available: + * + * * `{{median(<name>>)}}` Calculate median of a measurement data set, e.g. `{{median(elapsed)}}`. + * + * * `{{average(<name>)}}` Average (mean) calculation. + * + * * `{{medianAbsolutePercentError(<name>)}}` Calculates MdAPE, the Median Absolute Percentage Error. The MdAPE is an excellent + * metric for the variation of measurements. It is more robust to outliers than the + * [Mean absolute percentage error (M-APE)](https://en.wikipedia.org/wiki/Mean_absolute_percentage_error). + * @f[ + * \mathrm{medianAbsolutePercentError}(e) = \mathrm{median}\{| \frac{e_i - \mathrm{median}\{e\}}{e_i}| \} + * @f] + * E.g. for *elapsed*: First, @f$ \mathrm{median}\{elapsed\} @f$ is calculated. This is used to calculate the absolute percentage + * error to this median for each measurement, as in @f$ | \frac{e_i - \mathrm{median}\{e\}}{e_i}| @f$. All these results + * are sorted, and the middle value is chosen as the median absolute percent error. + * + * This measurement is a bit hard to interpret, but it is very robust against outliers. E.g. a value of 5% means that half of the + * measurements deviate less than 5% from the median, and the other deviate more than 5% from the median. + * + * * `{{sum(<name>)}}` Sums of all the measurements. E.g. `{{sum(iterations)}}` will give you the total number of iterations +* measured in this benchmark. + * + * * `{{minimum(<name>)}}` Minimum of all measurements. + * + * * `{{maximum(<name>)}}` Maximum of all measurements. + * + * * `{{sumProduct(<first>, <second>)}}` Calculates the sum of the products of corresponding measures: + * @f[ + * \mathrm{sumProduct}(a,b) = \sum_{i=1}^{n}a_i\cdot b_i + * @f] + * E.g. to calculate total runtime of the benchmark, you multiply iterations with elapsed time for each measurement, and + * sum these results up: + * `{{sumProduct(iterations, elapsed)}}`. + * + * * `{{#measurement}}` To access individual measurement results, open the begin tag for measurements. + * + * * `{{elapsed}}` Average elapsed time per iteration, in seconds. + * + * * `{{iterations}}` Number of iterations in the measurement. The number of iterations will fluctuate due + * to some applied randomness, to enhance accuracy. + * + * * `{{pagefaults}}` Average number of pagefaults per iteration. + * + * * `{{cpucycles}}` Average number of CPU cycles processed per iteration. + * + * * `{{contextswitches}}` Average number of context switches per iteration. + * + * * `{{instructions}}` Average number of retired instructions per iteration. + * + * * `{{branchinstructions}}` Average number of branches executed per iteration. + * + * * `{{branchmisses}}` Average number of branches that were missed per iteration. + * + * * `{{/measurement}}` Ends the measurement tag. + * + * * `{{/result}}` Marks the end of the result layer. This is the end marker for the template part that will be instantiated + * for each benchmark result. + * + * + * For the layer tags *result* and *measurement* you additionally can use these special markers: + * + * * ``{{#-first}}`` - Begin marker of a template that will be instantiated *only for the first* entry in the layer. Use is only + * allowed between the begin and end marker of the layer allowed. So between ``{{#result}}`` and ``{{/result}}``, or between + * ``{{#measurement}}`` and ``{{/measurement}}``. Finish the template with ``{{/-first}}``. + * + * * ``{{^-first}}`` - Begin marker of a template that will be instantiated *for each except the first* entry in the layer. This, + * this is basically the inversion of ``{{#-first}}``. Use is only allowed between the begin and end marker of the layer allowed. + * So between ``{{#result}}`` and ``{{/result}}``, or between ``{{#measurement}}`` and ``{{/measurement}}``. + * + * * ``{{/-first}}`` - End marker for either ``{{#-first}}`` or ``{{^-first}}``. + * + * * ``{{#-last}}`` - Begin marker of a template that will be instantiated *only for the last* entry in the layer. Use is only + * allowed between the begin and end marker of the layer allowed. So between ``{{#result}}`` and ``{{/result}}``, or between + * ``{{#measurement}}`` and ``{{/measurement}}``. Finish the template with ``{{/-last}}``. + * + * * ``{{^-last}}`` - Begin marker of a template that will be instantiated *for each except the last* entry in the layer. This, + * this is basically the inversion of ``{{#-last}}``. Use is only allowed between the begin and end marker of the layer allowed. + * So between ``{{#result}}`` and ``{{/result}}``, or between ``{{#measurement}}`` and ``{{/measurement}}``. + * + * * ``{{/-last}}`` - End marker for either ``{{#-last}}`` or ``{{^-last}}``. + * + @verbatim embed:rst + + For an overview of all the possible data you can get out of nanobench, please see the tutorial at :ref:`tutorial-template-json`. + + The templates that ship with nanobench are: + + * :cpp:func:`templates::csv() <ankerl::nanobench::templates::csv()>` + * :cpp:func:`templates::json() <ankerl::nanobench::templates::json()>` + * :cpp:func:`templates::htmlBoxplot() <ankerl::nanobench::templates::htmlBoxplot()>` + + @endverbatim + * + * @param mustacheTemplate The template. + * @param bench Benchmark, containing all the results. + * @param out Output for the generated output. + */ +void render(char const* mustacheTemplate, Bench const& bench, std::ostream& out); + +/** + * Same as render(char const* mustacheTemplate, Bench const& bench, std::ostream& out), but for when + * you only have results available. + * + * @param mustacheTemplate The template. + * @param results All the results to be used for rendering. + * @param out Output for the generated output. + */ +void render(char const* mustacheTemplate, std::vector<Result> const& results, std::ostream& out); + +// Contains mustache-like templates +namespace templates { + +/*! + @brief CSV data for the benchmark results. + + Generates a comma-separated values dataset. First line is the header, each following line is a summary of each benchmark run. + + @verbatim embed:rst + See the tutorial at :ref:`tutorial-template-csv` for an example. + @endverbatim + */ +char const* csv() noexcept; + +/*! + @brief HTML output that uses plotly to generate an interactive boxplot chart. See the tutorial for an example output. + + The output uses only the elapsed time, and displays each epoch as a single dot. + @verbatim embed:rst + See the tutorial at :ref:`tutorial-template-html` for an example. + @endverbatim + + @see ankerl::nanobench::render() + */ +char const* htmlBoxplot() noexcept; + +/*! + @brief Template to generate JSON data. + + The generated JSON data contains *all* data that has been generated. All times are as double values, in seconds. The output can get + quite large. + @verbatim embed:rst + See the tutorial at :ref:`tutorial-template-json` for an example. + @endverbatim + */ +char const* json() noexcept; + +} // namespace templates + +namespace detail { + +template <typename T> +struct PerfCountSet; + +class IterationLogic; +class PerformanceCounters; + +#if ANKERL_NANOBENCH(PERF_COUNTERS) +class LinuxPerformanceCounters; +#endif + +} // namespace detail +} // namespace nanobench +} // namespace ankerl + +// definitions //////////////////////////////////////////////////////////////////////////////////// + +namespace ankerl { +namespace nanobench { +namespace detail { + +template <typename T> +struct PerfCountSet { + T pageFaults{}; + T cpuCycles{}; + T contextSwitches{}; + T instructions{}; + T branchInstructions{}; + T branchMisses{}; +}; + +} // namespace detail + +ANKERL_NANOBENCH(IGNORE_PADDED_PUSH) +struct Config { + // actual benchmark config + std::string mBenchmarkTitle = "benchmark"; + std::string mBenchmarkName = "noname"; + std::string mUnit = "op"; + double mBatch = 1.0; + double mComplexityN = -1.0; + size_t mNumEpochs = 11; + size_t mClockResolutionMultiple = static_cast<size_t>(1000); + std::chrono::nanoseconds mMaxEpochTime = std::chrono::milliseconds(100); + std::chrono::nanoseconds mMinEpochTime{}; + uint64_t mMinEpochIterations{1}; + uint64_t mEpochIterations{0}; // If not 0, run *exactly* these number of iterations per epoch. + uint64_t mWarmup = 0; + std::ostream* mOut = nullptr; + bool mShowPerformanceCounters = true; + bool mIsRelative = false; + + Config(); + ~Config(); + Config& operator=(Config const&); + Config& operator=(Config&&); + Config(Config const&); + Config(Config&&) noexcept; +}; +ANKERL_NANOBENCH(IGNORE_PADDED_POP) + +// Result returned after a benchmark has finished. Can be used as a baseline for relative(). +ANKERL_NANOBENCH(IGNORE_PADDED_PUSH) +class Result { +public: + enum class Measure : size_t { + elapsed, + iterations, + pagefaults, + cpucycles, + contextswitches, + instructions, + branchinstructions, + branchmisses, + _size + }; + + explicit Result(Config const& benchmarkConfig); + + ~Result(); + Result& operator=(Result const&); + Result& operator=(Result&&); + Result(Result const&); + Result(Result&&) noexcept; + + // adds new measurement results + // all values are scaled by iters (except iters...) + void add(Clock::duration totalElapsed, uint64_t iters, detail::PerformanceCounters const& pc); + + ANKERL_NANOBENCH(NODISCARD) Config const& config() const noexcept; + + ANKERL_NANOBENCH(NODISCARD) double median(Measure m) const; + ANKERL_NANOBENCH(NODISCARD) double medianAbsolutePercentError(Measure m) const; + ANKERL_NANOBENCH(NODISCARD) double average(Measure m) const; + ANKERL_NANOBENCH(NODISCARD) double sum(Measure m) const noexcept; + ANKERL_NANOBENCH(NODISCARD) double sumProduct(Measure m1, Measure m2) const noexcept; + ANKERL_NANOBENCH(NODISCARD) double minimum(Measure m) const noexcept; + ANKERL_NANOBENCH(NODISCARD) double maximum(Measure m) const noexcept; + + ANKERL_NANOBENCH(NODISCARD) bool has(Measure m) const noexcept; + ANKERL_NANOBENCH(NODISCARD) double get(size_t idx, Measure m) const; + ANKERL_NANOBENCH(NODISCARD) bool empty() const noexcept; + ANKERL_NANOBENCH(NODISCARD) size_t size() const noexcept; + + // Finds string, if not found, returns _size. + static Measure fromString(std::string const& str); + +private: + Config mConfig{}; + std::vector<std::vector<double>> mNameToMeasurements{}; +}; +ANKERL_NANOBENCH(IGNORE_PADDED_POP) + +/** + * An extremely fast random generator. Currently, this implements *RomuDuoJr*, developed by Mark Overton. Source: + * http://www.romu-random.org/ + * + * RomuDuoJr is extremely fast and provides reasonable good randomness. Not enough for large jobs, but definitely + * good enough for a benchmarking framework. + * + * * Estimated capacity: @f$ 2^{51} @f$ bytes + * * Register pressure: 4 + * * State size: 128 bits + * + * This random generator is a drop-in replacement for the generators supplied by ``<random>``. It is not + * cryptographically secure. It's intended purpose is to be very fast so that benchmarks that make use + * of randomness are not distorted too much by the random generator. + * + * Rng also provides a few non-standard helpers, optimized for speed. + */ +class Rng final { +public: + /** + * @brief This RNG provides 64bit randomness. + */ + using result_type = uint64_t; + + static constexpr uint64_t(min)(); + static constexpr uint64_t(max)(); + + /** + * As a safety precausion, we don't allow copying. Copying a PRNG would mean you would have two random generators that produce the + * same sequence, which is generally not what one wants. Instead create a new rng with the default constructor Rng(), which is + * automatically seeded from `std::random_device`. If you really need a copy, use copy(). + */ + Rng(Rng const&) = delete; + + /** + * Same as Rng(Rng const&), we don't allow assignment. If you need a new Rng create one with the default constructor Rng(). + */ + Rng& operator=(Rng const&) = delete; + + // moving is ok + Rng(Rng&&) noexcept = default; + Rng& operator=(Rng&&) noexcept = default; + ~Rng() noexcept = default; + + /** + * @brief Creates a new Random generator with random seed. + * + * Instead of a default seed (as the random generators from the STD), this properly seeds the random generator from + * `std::random_device`. It guarantees correct seeding. Note that seeding can be relatively slow, depending on the source of + * randomness used. So it is best to create a Rng once and use it for all your randomness purposes. + */ + Rng(); + + /*! + Creates a new Rng that is seeded with a specific seed. Each Rng created from the same seed will produce the same randomness + sequence. This can be useful for deterministic behavior. + + @verbatim embed:rst + .. note:: + + The random algorithm might change between nanobench releases. Whenever a faster and/or better random + generator becomes available, I will switch the implementation. + @endverbatim + + As per the Romu paper, this seeds the Rng with splitMix64 algorithm and performs 10 initial rounds for further mixing up of the + internal state. + + @param seed The 64bit seed. All values are allowed, even 0. + */ + explicit Rng(uint64_t seed) noexcept; + Rng(uint64_t x, uint64_t y) noexcept; + + /** + * Creates a copy of the Rng, thus the copy provides exactly the same random sequence as the original. + */ + ANKERL_NANOBENCH(NODISCARD) Rng copy() const noexcept; + + /** + * @brief Produces a 64bit random value. This should be very fast, thus it is marked as inline. In my benchmark, this is ~46 times + * faster than `std::default_random_engine` for producing 64bit random values. It seems that the fastest std contender is + * `std::mt19937_64`. Still, this RNG is 2-3 times as fast. + * + * @return uint64_t The next 64 bit random value. + */ + inline uint64_t operator()() noexcept; + + // This is slightly biased. See + + /** + * Generates a random number between 0 and range (excluding range). + * + * The algorithm only produces 32bit numbers, and is slightly biased. The effect is quite small unless your range is close to the + * maximum value of an integer. It is possible to correct the bias with rejection sampling (see + * [here](https://lemire.me/blog/2016/06/30/fast-random-shuffling/), but this is most likely irrelevant in practices for the + * purposes of this Rng. + * + * See Daniel Lemire's blog post [A fast alternative to the modulo + * reduction](https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/) + * + * @param range Upper exclusive range. E.g a value of 3 will generate random numbers 0, 1, 2. + * @return uint32_t Generated random values in range [0, range(. + */ + inline uint32_t bounded(uint32_t range) noexcept; + + // random double in range [0, 1( + // see http://prng.di.unimi.it/ + + /** + * Provides a random uniform double value between 0 and 1. This uses the method described in [Generating uniform doubles in the + * unit interval](http://prng.di.unimi.it/), and is extremely fast. + * + * @return double Uniformly distributed double value in range [0,1(, excluding 1. + */ + inline double uniform01() noexcept; + + /** + * Shuffles all entries in the given container. Although this has a slight bias due to the implementation of bounded(), this is + * preferable to `std::shuffle` because it is over 5 times faster. See Daniel Lemire's blog post [Fast random + * shuffling](https://lemire.me/blog/2016/06/30/fast-random-shuffling/). + * + * @param container The whole container will be shuffled. + */ + template <typename Container> + void shuffle(Container& container) noexcept; + +private: + static constexpr uint64_t rotl(uint64_t x, unsigned k) noexcept; + + uint64_t mX; + uint64_t mY; +}; + +/** + * @brief Main entry point to nanobench's benchmarking facility. + * + * It holds configuration and results from one or more benchmark runs. Usually it is used in a single line, where the object is + * constructed, configured, and then a benchmark is run. E.g. like this: + * + * ankerl::nanobench::Bench().unit("byte").batch(1000).run("random fluctuations", [&] { + * // here be the benchmark code + * }); + * + * In that example Bench() constructs the benchmark, it is then configured with unit() and batch(), and after configuration a + * benchmark is executed with run(). Once run() has finished, it prints the result to `std::cout`. It would also store the results + * in the Bench instance, but in this case the object is immediately destroyed so it's not available any more. + */ +ANKERL_NANOBENCH(IGNORE_PADDED_PUSH) +class Bench { +public: + /** + * @brief Creates a new benchmark for configuration and running of benchmarks. + */ + Bench(); + + Bench(Bench&& other); + Bench& operator=(Bench&& other); + Bench(Bench const& other); + Bench& operator=(Bench const& other); + ~Bench() noexcept; + + /*! + @brief Repeatedly calls `op()` based on the configuration, and performs measurements. + + This call is marked with `noinline` to prevent the compiler to optimize beyond different benchmarks. This can have quite a big + effect on benchmark accuracy. + + @verbatim embed:rst + .. note:: + + Each call to your lambda must have a side effect that the compiler can't possibly optimize it away. E.g. add a result to an + externally defined number (like `x` in the above example), and finally call `doNotOptimizeAway` on the variables the compiler + must not remove. You can also use :cpp:func:`ankerl::nanobench::doNotOptimizeAway` directly in the lambda, but be aware that + this has a small overhead. + + @endverbatim + + @tparam Op The code to benchmark. + */ + template <typename Op> + ANKERL_NANOBENCH(NOINLINE) + Bench& run(char const* benchmarkName, Op&& op); + + template <typename Op> + ANKERL_NANOBENCH(NOINLINE) + Bench& run(std::string const& benchmarkName, Op&& op); + + /** + * @brief Same as run(char const* benchmarkName, Op op), but instead uses the previously set name. + * @tparam Op The code to benchmark. + */ + template <typename Op> + ANKERL_NANOBENCH(NOINLINE) + Bench& run(Op&& op); + + /** + * @brief Title of the benchmark, will be shown in the table header. Changing the title will start a new markdown table. + * + * @param benchmarkTitle The title of the benchmark. + */ + Bench& title(char const* benchmarkTitle); + Bench& title(std::string const& benchmarkTitle); + ANKERL_NANOBENCH(NODISCARD) std::string const& title() const noexcept; + + /// Name of the benchmark, will be shown in the table row. + Bench& name(char const* benchmarkName); + Bench& name(std::string const& benchmarkName); + ANKERL_NANOBENCH(NODISCARD) std::string const& name() const noexcept; + + /** + * @brief Sets the batch size. + * + * E.g. number of processed byte, or some other metric for the size of the processed data in each iteration. If you benchmark + * hashing of a 1000 byte long string and want byte/sec as a result, you can specify 1000 as the batch size. + * + * @tparam T Any input type is internally cast to `double`. + * @param b batch size + */ + template <typename T> + Bench& batch(T b) noexcept; + ANKERL_NANOBENCH(NODISCARD) double batch() const noexcept; + + /** + * @brief Sets the operation unit. + * + * Defaults to "op". Could be e.g. "byte" for string processing. This is used for the table header, e.g. to show `ns/byte`. Use + * singular (*byte*, not *bytes*). A change clears the currently collected results. + * + * @param unit The unit name. + */ + Bench& unit(char const* unit); + Bench& unit(std::string const& unit); + ANKERL_NANOBENCH(NODISCARD) std::string const& unit() const noexcept; + + /** + * @brief Set the output stream where the resulting markdown table will be printed to. + * + * The default is `&std::cout`. You can disable all output by setting `nullptr`. + * + * @param outstream Pointer to output stream, can be `nullptr`. + */ + Bench& output(std::ostream* outstream) noexcept; + ANKERL_NANOBENCH(NODISCARD) std::ostream* output() const noexcept; + + /** + * Modern processors have a very accurate clock, being able to measure as low as 20 nanoseconds. This is the main trick nanobech to + * be so fast: we find out how accurate the clock is, then run the benchmark only so often that the clock's accuracy is good enough + * for accurate measurements. + * + * The default is to run one epoch for 1000 times the clock resolution. So for 20ns resolution and 11 epochs, this gives a total + * runtime of + * + * @f[ + * 20ns * 1000 * 11 \approx 0.2ms + * @f] + * + * To be precise, nanobench adds a 0-20% random noise to each evaluation. This is to prevent any aliasing effects, and further + * improves accuracy. + * + * Total runtime will be higher though: Some initial time is needed to find out the target number of iterations for each epoch, and + * there is some overhead involved to start & stop timers and calculate resulting statistics and writing the output. + * + * @param multiple Target number of times of clock resolution. Usually 1000 is a good compromise between runtime and accuracy. + */ + Bench& clockResolutionMultiple(size_t multiple) noexcept; + ANKERL_NANOBENCH(NODISCARD) size_t clockResolutionMultiple() const noexcept; + + /** + * @brief Controls number of epochs, the number of measurements to perform. + * + * The reported result will be the median of evaluation of each epoch. The higher you choose this, the more + * deterministic the result be and outliers will be more easily removed. Also the `err%` will be more accurate the higher this + * number is. Note that the `err%` will not necessarily decrease when number of epochs is increased. But it will be a more accurate + * representation of the benchmarked code's runtime stability. + * + * Choose the value wisely. In practice, 11 has been shown to be a reasonable choice between runtime performance and accuracy. + * This setting goes hand in hand with minEpocIterations() (or minEpochTime()). If you are more interested in *median* runtime, you + * might want to increase epochs(). If you are more interested in *mean* runtime, you might want to increase minEpochIterations() + * instead. + * + * @param numEpochs Number of epochs. + */ + Bench& epochs(size_t numEpochs) noexcept; + ANKERL_NANOBENCH(NODISCARD) size_t epochs() const noexcept; + + /** + * @brief Upper limit for the runtime of each epoch. + * + * As a safety precausion if the clock is not very accurate, we can set an upper limit for the maximum evaluation time per + * epoch. Default is 100ms. At least a single evaluation of the benchmark is performed. + * + * @see minEpochTime(), minEpochIterations() + * + * @param t Maximum target runtime for a single epoch. + */ + Bench& maxEpochTime(std::chrono::nanoseconds t) noexcept; + ANKERL_NANOBENCH(NODISCARD) std::chrono::nanoseconds maxEpochTime() const noexcept; + + /** + * @brief Minimum time each epoch should take. + * + * Default is zero, so we are fully relying on clockResolutionMultiple(). In most cases this is exactly what you want. If you see + * that the evaluation is unreliable with a high `err%`, you can increase either minEpochTime() or minEpochIterations(). + * + * @see maxEpochTime(), minEpochIterations() + * + * @param t Minimum time each epoch should take. + */ + Bench& minEpochTime(std::chrono::nanoseconds t) noexcept; + ANKERL_NANOBENCH(NODISCARD) std::chrono::nanoseconds minEpochTime() const noexcept; + + /** + * @brief Sets the minimum number of iterations each epoch should take. + * + * Default is 1, and we rely on clockResolutionMultiple(). If the `err%` is high and you want a more smooth result, you might want + * to increase the minimum number or iterations, or increase the minEpochTime(). + * + * @see minEpochTime(), maxEpochTime(), minEpochIterations() + * + * @param numIters Minimum number of iterations per epoch. + */ + Bench& minEpochIterations(uint64_t numIters) noexcept; + ANKERL_NANOBENCH(NODISCARD) uint64_t minEpochIterations() const noexcept; + + /** + * Sets exactly the number of iterations for each epoch. Ignores all other epoch limits. This forces nanobench to use exactly + * the given number of iterations for each epoch, not more and not less. Default is 0 (disabled). + * + * @param numIters Exact number of iterations to use. Set to 0 to disable. + */ + Bench& epochIterations(uint64_t numIters) noexcept; + ANKERL_NANOBENCH(NODISCARD) uint64_t epochIterations() const noexcept; + + /** + * @brief Sets a number of iterations that are initially performed without any measurements. + * + * Some benchmarks need a few evaluations to warm up caches / database / whatever access. Normally this should not be needed, since + * we show the median result so initial outliers will be filtered away automatically. If the warmup effect is large though, you + * might want to set it. Default is 0. + * + * @param numWarmupIters Number of warmup iterations. + */ + Bench& warmup(uint64_t numWarmupIters) noexcept; + ANKERL_NANOBENCH(NODISCARD) uint64_t warmup() const noexcept; + + /** + * @brief Marks the next run as the baseline. + * + * Call `relative(true)` to mark the run as the baseline. Successive runs will be compared to this run. It is calculated by + * + * @f[ + * 100\% * \frac{baseline}{runtime} + * @f] + * + * * 100% means it is exactly as fast as the baseline + * * >100% means it is faster than the baseline. E.g. 200% means the current run is twice as fast as the baseline. + * * <100% means it is slower than the baseline. E.g. 50% means it is twice as slow as the baseline. + * + * See the tutorial section "Comparing Results" for example usage. + * + * @param isRelativeEnabled True to enable processing + */ + Bench& relative(bool isRelativeEnabled) noexcept; + ANKERL_NANOBENCH(NODISCARD) bool relative() const noexcept; + + /** + * @brief Enables/disables performance counters. + * + * On Linux nanobench has a powerful feature to use performance counters. This enables counting of retired instructions, count + * number of branches, missed branches, etc. On default this is enabled, but you can disable it if you don't need that feature. + * + * @param showPerformanceCounters True to enable, false to disable. + */ + Bench& performanceCounters(bool showPerformanceCounters) noexcept; + ANKERL_NANOBENCH(NODISCARD) bool performanceCounters() const noexcept; + + /** + * @brief Retrieves all benchmark results collected by the bench object so far. + * + * Each call to run() generates a Result that is stored within the Bench instance. This is mostly for advanced users who want to + * see all the nitty gritty detials. + * + * @return All results collected so far. + */ + ANKERL_NANOBENCH(NODISCARD) std::vector<Result> const& results() const noexcept; + + /*! + @verbatim embed:rst + + Convenience shortcut to :cpp:func:`ankerl::nanobench::doNotOptimizeAway`. + + @endverbatim + */ + template <typename Arg> + Bench& doNotOptimizeAway(Arg&& arg); + + /*! + @verbatim embed:rst + + Sets N for asymptotic complexity calculation, so it becomes possible to calculate `Big O + <https://en.wikipedia.org/wiki/Big_O_notation>`_ from multiple benchmark evaluations. + + Use :cpp:func:`ankerl::nanobench::Bench::complexityBigO` when the evaluation has finished. See the tutorial + :ref:`asymptotic-complexity` for details. + + @endverbatim + + @tparam T Any type is cast to `double`. + @param b Length of N for the next benchmark run, so it is possible to calculate `bigO`. + */ + template <typename T> + Bench& complexityN(T b) noexcept; + ANKERL_NANOBENCH(NODISCARD) double complexityN() const noexcept; + + /*! + Calculates [Big O](https://en.wikipedia.org/wiki/Big_O_notation>) of the results with all preconfigured complexity functions. + Currently these complexity functions are fitted into the benchmark results: + + @f$ \mathcal{O}(1) @f$, + @f$ \mathcal{O}(n) @f$, + @f$ \mathcal{O}(\log{}n) @f$, + @f$ \mathcal{O}(n\log{}n) @f$, + @f$ \mathcal{O}(n^2) @f$, + @f$ \mathcal{O}(n^3) @f$. + + If we e.g. evaluate the complexity of `std::sort`, this is the result of `std::cout << bench.complexityBigO()`: + + ``` + | coefficient | err% | complexity + |--------------:|-------:|------------ + | 5.08935e-09 | 2.6% | O(n log n) + | 6.10608e-08 | 8.0% | O(n) + | 1.29307e-11 | 47.2% | O(n^2) + | 2.48677e-15 | 69.6% | O(n^3) + | 9.88133e-06 | 132.3% | O(log n) + | 5.98793e-05 | 162.5% | O(1) + ``` + + So in this case @f$ \mathcal{O}(n\log{}n) @f$ provides the best approximation. + + @verbatim embed:rst + See the tutorial :ref:`asymptotic-complexity` for details. + @endverbatim + @return Evaluation results, which can be printed or otherwise inspected. + */ + std::vector<BigO> complexityBigO() const; + + /** + * @brief Calculates bigO for a custom function. + * + * E.g. to calculate the mean squared error for @f$ \mathcal{O}(\log{}\log{}n) @f$, which is not part of the default set of + * complexityBigO(), you can do this: + * + * ``` + * auto logLogN = bench.complexityBigO("O(log log n)", [](double n) { + * return std::log2(std::log2(n)); + * }); + * ``` + * + * The resulting mean squared error can be printed with `std::cout << logLogN`. E.g. it prints something like this: + * + * ```text + * 2.46985e-05 * O(log log n), rms=1.48121 + * ``` + * + * @tparam Op Type of mapping operation. + * @param name Name for the function, e.g. "O(log log n)" + * @param op Op's operator() maps a `double` with the desired complexity function, e.g. `log2(log2(n))`. + * @return BigO Error calculation, which is streamable to std::cout. + */ + template <typename Op> + BigO complexityBigO(char const* name, Op op) const; + + template <typename Op> + BigO complexityBigO(std::string const& name, Op op) const; + + /*! + @verbatim embed:rst + + Convenience shortcut to :cpp:func:`ankerl::nanobench::render`. + + @endverbatim + */ + Bench& render(char const* templateContent, std::ostream& os); + + Bench& config(Config const& benchmarkConfig); + ANKERL_NANOBENCH(NODISCARD) Config const& config() const noexcept; + +private: + Config mConfig{}; + std::vector<Result> mResults{}; +}; +ANKERL_NANOBENCH(IGNORE_PADDED_POP) + +/** + * @brief Makes sure none of the given arguments are optimized away by the compiler. + * + * @tparam Arg Type of the argument that shouldn't be optimized away. + * @param arg The input that we mark as being used, even though we don't do anything with it. + */ +template <typename Arg> +void doNotOptimizeAway(Arg&& arg); + +namespace detail { + +#if defined(_MSC_VER) +void doNotOptimizeAwaySink(void const*); + +template <typename T> +void doNotOptimizeAway(T const& val); + +#else + +// see folly's Benchmark.h +template <typename T> +constexpr bool doNotOptimizeNeedsIndirect() { + using Decayed = typename std::decay<T>::type; + return !ANKERL_NANOBENCH_IS_TRIVIALLY_COPYABLE(Decayed) || sizeof(Decayed) > sizeof(long) || std::is_pointer<Decayed>::value; +} + +template <typename T> +typename std::enable_if<!doNotOptimizeNeedsIndirect<T>()>::type doNotOptimizeAway(T const& val) { + // NOLINTNEXTLINE(hicpp-no-assembler) + asm volatile("" ::"r"(val)); +} + +template <typename T> +typename std::enable_if<doNotOptimizeNeedsIndirect<T>()>::type doNotOptimizeAway(T const& val) { + // NOLINTNEXTLINE(hicpp-no-assembler) + asm volatile("" ::"m"(val) : "memory"); +} +#endif + +// internally used, but visible because run() is templated. +// Not movable/copy-able, so we simply use a pointer instead of unique_ptr. This saves us from +// having to include <memory>, and the template instantiation overhead of unique_ptr which is unfortunately quite significant. +ANKERL_NANOBENCH(IGNORE_EFFCPP_PUSH) +class IterationLogic { +public: + explicit IterationLogic(Bench const& config) noexcept; + ~IterationLogic(); + + ANKERL_NANOBENCH(NODISCARD) uint64_t numIters() const noexcept; + void add(std::chrono::nanoseconds elapsed, PerformanceCounters const& pc) noexcept; + void moveResultTo(std::vector<Result>& results) noexcept; + +private: + struct Impl; + Impl* mPimpl; +}; +ANKERL_NANOBENCH(IGNORE_EFFCPP_POP) + +ANKERL_NANOBENCH(IGNORE_PADDED_PUSH) +class PerformanceCounters { +public: + PerformanceCounters(PerformanceCounters const&) = delete; + PerformanceCounters& operator=(PerformanceCounters const&) = delete; + + PerformanceCounters(); + ~PerformanceCounters(); + + void beginMeasure(); + void endMeasure(); + void updateResults(uint64_t numIters); + + ANKERL_NANOBENCH(NODISCARD) PerfCountSet<uint64_t> const& val() const noexcept; + ANKERL_NANOBENCH(NODISCARD) PerfCountSet<bool> const& has() const noexcept; + +private: +#if ANKERL_NANOBENCH(PERF_COUNTERS) + LinuxPerformanceCounters* mPc = nullptr; +#endif + PerfCountSet<uint64_t> mVal{}; + PerfCountSet<bool> mHas{}; +}; +ANKERL_NANOBENCH(IGNORE_PADDED_POP) + +// Gets the singleton +PerformanceCounters& performanceCounters(); + +} // namespace detail + +class BigO { +public: + using RangeMeasure = std::vector<std::pair<double, double>>; + + template <typename Op> + static RangeMeasure mapRangeMeasure(RangeMeasure data, Op op) { + for (auto& rangeMeasure : data) { + rangeMeasure.first = op(rangeMeasure.first); + } + return data; + } + + static RangeMeasure collectRangeMeasure(std::vector<Result> const& results); + + template <typename Op> + BigO(char const* bigOName, RangeMeasure const& rangeMeasure, Op rangeToN) + : BigO(bigOName, mapRangeMeasure(rangeMeasure, rangeToN)) {} + + template <typename Op> + BigO(std::string const& bigOName, RangeMeasure const& rangeMeasure, Op rangeToN) + : BigO(bigOName, mapRangeMeasure(rangeMeasure, rangeToN)) {} + + BigO(char const* bigOName, RangeMeasure const& scaledRangeMeasure); + BigO(std::string const& bigOName, RangeMeasure const& scaledRangeMeasure); + ANKERL_NANOBENCH(NODISCARD) std::string const& name() const noexcept; + ANKERL_NANOBENCH(NODISCARD) double constant() const noexcept; + ANKERL_NANOBENCH(NODISCARD) double normalizedRootMeanSquare() const noexcept; + ANKERL_NANOBENCH(NODISCARD) bool operator<(BigO const& other) const noexcept; + +private: + std::string mName{}; + double mConstant{}; + double mNormalizedRootMeanSquare{}; +}; +std::ostream& operator<<(std::ostream& os, BigO const& bigO); +std::ostream& operator<<(std::ostream& os, std::vector<ankerl::nanobench::BigO> const& bigOs); + +} // namespace nanobench +} // namespace ankerl + +// implementation ///////////////////////////////////////////////////////////////////////////////// + +namespace ankerl { +namespace nanobench { + +constexpr uint64_t(Rng::min)() { + return 0; +} + +constexpr uint64_t(Rng::max)() { + return (std::numeric_limits<uint64_t>::max)(); +} + +ANKERL_NANOBENCH_NO_SANITIZE("integer") +uint64_t Rng::operator()() noexcept { + auto x = mX; + + mX = UINT64_C(15241094284759029579) * mY; + mY = rotl(mY - x, 27); + + return x; +} + +ANKERL_NANOBENCH_NO_SANITIZE("integer") +uint32_t Rng::bounded(uint32_t range) noexcept { + uint64_t r32 = static_cast<uint32_t>(operator()()); + auto multiresult = r32 * range; + return static_cast<uint32_t>(multiresult >> 32U); +} + +double Rng::uniform01() noexcept { + auto i = (UINT64_C(0x3ff) << 52U) | (operator()() >> 12U); + // can't use union in c++ here for type puning, it's undefined behavior. + // std::memcpy is optimized anyways. + double d; + std::memcpy(&d, &i, sizeof(double)); + return d - 1.0; +} + +template <typename Container> +void Rng::shuffle(Container& container) noexcept { + auto size = static_cast<uint32_t>(container.size()); + for (auto i = size; i > 1U; --i) { + using std::swap; + auto p = bounded(i); // number in [0, i) + swap(container[i - 1], container[p]); + } +} + +constexpr uint64_t Rng::rotl(uint64_t x, unsigned k) noexcept { + return (x << k) | (x >> (64U - k)); +} + +template <typename Op> +ANKERL_NANOBENCH_NO_SANITIZE("integer") +Bench& Bench::run(Op&& op) { + // It is important that this method is kept short so the compiler can do better optimizations/ inlining of op() + detail::IterationLogic iterationLogic(*this); + auto& pc = detail::performanceCounters(); + + while (auto n = iterationLogic.numIters()) { + pc.beginMeasure(); + Clock::time_point before = Clock::now(); + while (n-- > 0) { + op(); + } + Clock::time_point after = Clock::now(); + pc.endMeasure(); + pc.updateResults(iterationLogic.numIters()); + iterationLogic.add(after - before, pc); + } + iterationLogic.moveResultTo(mResults); + return *this; +} + +// Performs all evaluations. +template <typename Op> +Bench& Bench::run(char const* benchmarkName, Op&& op) { + name(benchmarkName); + return run(std::forward<Op>(op)); +} + +template <typename Op> +Bench& Bench::run(std::string const& benchmarkName, Op&& op) { + name(benchmarkName); + return run(std::forward<Op>(op)); +} + +template <typename Op> +BigO Bench::complexityBigO(char const* benchmarkName, Op op) const { + return BigO(benchmarkName, BigO::collectRangeMeasure(mResults), op); +} + +template <typename Op> +BigO Bench::complexityBigO(std::string const& benchmarkName, Op op) const { + return BigO(benchmarkName, BigO::collectRangeMeasure(mResults), op); +} + +// Set the batch size, e.g. number of processed bytes, or some other metric for the size of the processed data in each iteration. +// Any argument is cast to double. +template <typename T> +Bench& Bench::batch(T b) noexcept { + mConfig.mBatch = static_cast<double>(b); + return *this; +} + +// Sets the computation complexity of the next run. Any argument is cast to double. +template <typename T> +Bench& Bench::complexityN(T n) noexcept { + mConfig.mComplexityN = static_cast<double>(n); + return *this; +} + +// Convenience: makes sure none of the given arguments are optimized away by the compiler. +template <typename Arg> +Bench& Bench::doNotOptimizeAway(Arg&& arg) { + detail::doNotOptimizeAway(std::forward<Arg>(arg)); + return *this; +} + +// Makes sure none of the given arguments are optimized away by the compiler. +template <typename Arg> +void doNotOptimizeAway(Arg&& arg) { + detail::doNotOptimizeAway(std::forward<Arg>(arg)); +} + +namespace detail { + +#if defined(_MSC_VER) +template <typename T> +void doNotOptimizeAway(T const& val) { + doNotOptimizeAwaySink(&val); +} + +#endif + +} // namespace detail +} // namespace nanobench +} // namespace ankerl + +#if defined(ANKERL_NANOBENCH_IMPLEMENT) + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// implementation part - only visible in .cpp +/////////////////////////////////////////////////////////////////////////////////////////////////// + +# include <algorithm> // sort, reverse +# include <atomic> // compare_exchange_strong in loop overhead +# include <cstdlib> // getenv +# include <cstring> // strstr, strncmp +# include <fstream> // ifstream to parse proc files +# include <iomanip> // setw, setprecision +# include <iostream> // cout +# include <numeric> // accumulate +# include <random> // random_device +# include <sstream> // to_s in Number +# include <stdexcept> // throw for rendering templates +# include <tuple> // std::tie +# if defined(__linux__) +# include <unistd.h> //sysconf +# endif +# if ANKERL_NANOBENCH(PERF_COUNTERS) +# include <map> // map + +# include <linux/perf_event.h> +# include <sys/ioctl.h> +# include <sys/syscall.h> +# include <unistd.h> +# endif + +// declarations /////////////////////////////////////////////////////////////////////////////////// + +namespace ankerl { +namespace nanobench { + +// helper stuff that is only intended to be used internally +namespace detail { + +struct TableInfo; + +// formatting utilities +namespace fmt { + +class NumSep; +class StreamStateRestorer; +class Number; +class MarkDownColumn; +class MarkDownCode; + +} // namespace fmt +} // namespace detail +} // namespace nanobench +} // namespace ankerl + +// definitions //////////////////////////////////////////////////////////////////////////////////// + +namespace ankerl { +namespace nanobench { + +uint64_t splitMix64(uint64_t& state) noexcept; + +namespace detail { + +// helpers to get double values +template <typename T> +inline double d(T t) noexcept { + return static_cast<double>(t); +} +inline double d(Clock::duration duration) noexcept { + return std::chrono::duration_cast<std::chrono::duration<double>>(duration).count(); +} + +// Calculates clock resolution once, and remembers the result +inline Clock::duration clockResolution() noexcept; + +} // namespace detail + +namespace templates { + +char const* csv() noexcept { + return R"DELIM("title";"name";"unit";"batch";"elapsed";"error %";"instructions";"branches";"branch misses";"total" +{{#result}}"{{title}}";"{{name}}";"{{unit}}";{{batch}};{{median(elapsed)}};{{medianAbsolutePercentError(elapsed)}};{{median(instructions)}};{{median(branchinstructions)}};{{median(branchmisses)}};{{sumProduct(iterations, elapsed)}} +{{/result}})DELIM"; +} + +char const* htmlBoxplot() noexcept { + return R"DELIM(<html> + +<head> + <script src="https://cdn.plot.ly/plotly-latest.min.js"></script> +</head> + +<body> + <div id="myDiv"></div> + <script> + var data = [ + {{#result}}{ + name: '{{name}}', + y: [{{#measurement}}{{elapsed}}{{^-last}}, {{/last}}{{/measurement}}], + }, + {{/result}} + ]; + var title = '{{title}}'; + + data = data.map(a => Object.assign(a, { boxpoints: 'all', pointpos: 0, type: 'box' })); + var layout = { title: { text: title }, showlegend: false, yaxis: { title: 'time per unit', rangemode: 'tozero', autorange: true } }; Plotly.newPlot('myDiv', data, layout, {responsive: true}); + </script> +</body> + +</html>)DELIM"; +} + +char const* json() noexcept { + return R"DELIM({ + "results": [ +{{#result}} { + "title": "{{title}}", + "name": "{{name}}", + "unit": "{{unit}}", + "batch": {{batch}}, + "complexityN": {{complexityN}}, + "epochs": {{epochs}}, + "clockResolution": {{clockResolution}}, + "clockResolutionMultiple": {{clockResolutionMultiple}}, + "maxEpochTime": {{maxEpochTime}}, + "minEpochTime": {{minEpochTime}}, + "minEpochIterations": {{minEpochIterations}}, + "epochIterations": {{epochIterations}}, + "warmup": {{warmup}}, + "relative": {{relative}}, + "median(elapsed)": {{median(elapsed)}}, + "medianAbsolutePercentError(elapsed)": {{medianAbsolutePercentError(elapsed)}}, + "median(instructions)": {{median(instructions)}}, + "medianAbsolutePercentError(instructions)": {{medianAbsolutePercentError(instructions)}}, + "median(cpucycles)": {{median(cpucycles)}}, + "median(contextswitches)": {{median(contextswitches)}}, + "median(pagefaults)": {{median(pagefaults)}}, + "median(branchinstructions)": {{median(branchinstructions)}}, + "median(branchmisses)": {{median(branchmisses)}}, + "totalTime": {{sumProduct(iterations, elapsed)}}, + "measurements": [ +{{#measurement}} { + "iterations": {{iterations}}, + "elapsed": {{elapsed}}, + "pagefaults": {{pagefaults}}, + "cpucycles": {{cpucycles}}, + "contextswitches": {{contextswitches}}, + "instructions": {{instructions}}, + "branchinstructions": {{branchinstructions}}, + "branchmisses": {{branchmisses}} + }{{^-last}},{{/-last}} +{{/measurement}} ] + }{{^-last}},{{/-last}} +{{/result}} ] +})DELIM"; +} + +ANKERL_NANOBENCH(IGNORE_PADDED_PUSH) +struct Node { + enum class Type { tag, content, section, inverted_section }; + + char const* begin; + char const* end; + std::vector<Node> children; + Type type; + + template <size_t N> + // NOLINTNEXTLINE(hicpp-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) + bool operator==(char const (&str)[N]) const noexcept { + return static_cast<size_t>(std::distance(begin, end) + 1) == N && 0 == strncmp(str, begin, N - 1); + } +}; +ANKERL_NANOBENCH(IGNORE_PADDED_POP) + +static std::vector<Node> parseMustacheTemplate(char const** tpl) { + std::vector<Node> nodes; + + while (true) { + auto begin = std::strstr(*tpl, "{{"); + auto end = begin; + if (begin != nullptr) { + begin += 2; + end = std::strstr(begin, "}}"); + } + + if (begin == nullptr || end == nullptr) { + // nothing found, finish node + nodes.emplace_back(Node{*tpl, *tpl + std::strlen(*tpl), std::vector<Node>{}, Node::Type::content}); + return nodes; + } + + nodes.emplace_back(Node{*tpl, begin - 2, std::vector<Node>{}, Node::Type::content}); + + // we found a tag + *tpl = end + 2; + switch (*begin) { + case '/': + // finished! bail out + return nodes; + + case '#': + nodes.emplace_back(Node{begin + 1, end, parseMustacheTemplate(tpl), Node::Type::section}); + break; + + case '^': + nodes.emplace_back(Node{begin + 1, end, parseMustacheTemplate(tpl), Node::Type::inverted_section}); + break; + + default: + nodes.emplace_back(Node{begin, end, std::vector<Node>{}, Node::Type::tag}); + break; + } + } +} + +static bool generateFirstLast(Node const& n, size_t idx, size_t size, std::ostream& out) { + bool matchFirst = n == "-first"; + bool matchLast = n == "-last"; + if (!matchFirst && !matchLast) { + return false; + } + + bool doWrite = false; + if (n.type == Node::Type::section) { + doWrite = (matchFirst && idx == 0) || (matchLast && idx == size - 1); + } else if (n.type == Node::Type::inverted_section) { + doWrite = (matchFirst && idx != 0) || (matchLast && idx != size - 1); + } + + if (doWrite) { + for (auto const& child : n.children) { + if (child.type == Node::Type::content) { + out.write(child.begin, std::distance(child.begin, child.end)); + } + } + } + return true; +} + +static bool matchCmdArgs(std::string const& str, std::vector<std::string>& matchResult) { + matchResult.clear(); + auto idxOpen = str.find('('); + auto idxClose = str.find(')', idxOpen); + if (idxClose == std::string::npos) { + return false; + } + + matchResult.emplace_back(str.substr(0, idxOpen)); + + // split by comma + matchResult.emplace_back(std::string{}); + for (size_t i = idxOpen + 1; i != idxClose; ++i) { + if (str[i] == ' ' || str[i] == '\t') { + // skip whitespace + continue; + } + if (str[i] == ',') { + // got a comma => new string + matchResult.emplace_back(std::string{}); + continue; + } + // no whitespace no comma, append + matchResult.back() += str[i]; + } + return true; +} + +static bool generateConfigTag(Node const& n, Config const& config, std::ostream& out) { + using detail::d; + + if (n == "title") { + out << config.mBenchmarkTitle; + return true; + } else if (n == "name") { + out << config.mBenchmarkName; + return true; + } else if (n == "unit") { + out << config.mUnit; + return true; + } else if (n == "batch") { + out << config.mBatch; + return true; + } else if (n == "complexityN") { + out << config.mComplexityN; + return true; + } else if (n == "epochs") { + out << config.mNumEpochs; + return true; + } else if (n == "clockResolution") { + out << d(detail::clockResolution()); + return true; + } else if (n == "clockResolutionMultiple") { + out << config.mClockResolutionMultiple; + return true; + } else if (n == "maxEpochTime") { + out << d(config.mMaxEpochTime); + return true; + } else if (n == "minEpochTime") { + out << d(config.mMinEpochTime); + return true; + } else if (n == "minEpochIterations") { + out << config.mMinEpochIterations; + return true; + } else if (n == "epochIterations") { + out << config.mEpochIterations; + return true; + } else if (n == "warmup") { + out << config.mWarmup; + return true; + } else if (n == "relative") { + out << config.mIsRelative; + return true; + } + return false; +} + +static std::ostream& generateResultTag(Node const& n, Result const& r, std::ostream& out) { + if (generateConfigTag(n, r.config(), out)) { + return out; + } + // match e.g. "median(elapsed)" + // g++ 4.8 doesn't implement std::regex :( + // static std::regex const regOpArg1("^([a-zA-Z]+)\\(([a-zA-Z]*)\\)$"); + // std::cmatch matchResult; + // if (std::regex_match(n.begin, n.end, matchResult, regOpArg1)) { + std::vector<std::string> matchResult; + if (matchCmdArgs(std::string(n.begin, n.end), matchResult)) { + if (matchResult.size() == 2) { + auto m = Result::fromString(matchResult[1]); + if (m == Result::Measure::_size) { + return out << 0.0; + } + + if (matchResult[0] == "median") { + return out << r.median(m); + } + if (matchResult[0] == "average") { + return out << r.average(m); + } + if (matchResult[0] == "medianAbsolutePercentError") { + return out << r.medianAbsolutePercentError(m); + } + if (matchResult[0] == "sum") { + return out << r.sum(m); + } + if (matchResult[0] == "minimum") { + return out << r.minimum(m); + } + if (matchResult[0] == "maximum") { + return out << r.maximum(m); + } + } else if (matchResult.size() == 3) { + auto m1 = Result::fromString(matchResult[1]); + auto m2 = Result::fromString(matchResult[2]); + if (m1 == Result::Measure::_size || m2 == Result::Measure::_size) { + return out << 0.0; + } + + if (matchResult[0] == "sumProduct") { + return out << r.sumProduct(m1, m2); + } + } + } + + // match e.g. "sumProduct(elapsed, iterations)" + // static std::regex const regOpArg2("^([a-zA-Z]+)\\(([a-zA-Z]*)\\s*,\\s+([a-zA-Z]*)\\)$"); + + // nothing matches :( + throw std::runtime_error("command '" + std::string(n.begin, n.end) + "' not understood"); +} + +static void generateResultMeasurement(std::vector<Node> const& nodes, size_t idx, Result const& r, std::ostream& out) { + for (auto const& n : nodes) { + if (!generateFirstLast(n, idx, r.size(), out)) { + ANKERL_NANOBENCH_LOG("n.type=" << static_cast<int>(n.type)); + switch (n.type) { + case Node::Type::content: + out.write(n.begin, std::distance(n.begin, n.end)); + break; + + case Node::Type::inverted_section: + throw std::runtime_error("got a inverted section inside measurement"); + + case Node::Type::section: + throw std::runtime_error("got a section inside measurement"); + + case Node::Type::tag: { + auto m = Result::fromString(std::string(n.begin, n.end)); + if (m == Result::Measure::_size || !r.has(m)) { + out << 0.0; + } else { + out << r.get(idx, m); + } + break; + } + } + } + } +} + +static void generateResult(std::vector<Node> const& nodes, size_t idx, std::vector<Result> const& results, std::ostream& out) { + auto const& r = results[idx]; + for (auto const& n : nodes) { + if (!generateFirstLast(n, idx, results.size(), out)) { + ANKERL_NANOBENCH_LOG("n.type=" << static_cast<int>(n.type)); + switch (n.type) { + case Node::Type::content: + out.write(n.begin, std::distance(n.begin, n.end)); + break; + + case Node::Type::inverted_section: + throw std::runtime_error("got a inverted section inside result"); + + case Node::Type::section: + if (n == "measurement") { + for (size_t i = 0; i < r.size(); ++i) { + generateResultMeasurement(n.children, i, r, out); + } + } else { + throw std::runtime_error("got a section inside result"); + } + break; + + case Node::Type::tag: + generateResultTag(n, r, out); + break; + } + } + } +} + +} // namespace templates + +// helper stuff that only intended to be used internally +namespace detail { + +char const* getEnv(char const* name); +bool isEndlessRunning(std::string const& name); + +template <typename T> +T parseFile(std::string const& filename); + +void gatherStabilityInformation(std::vector<std::string>& warnings, std::vector<std::string>& recommendations); +void printStabilityInformationOnce(std::ostream* os); + +// remembers the last table settings used. When it changes, a new table header is automatically written for the new entry. +uint64_t& singletonHeaderHash() noexcept; + +// determines resolution of the given clock. This is done by measuring multiple times and returning the minimum time difference. +Clock::duration calcClockResolution(size_t numEvaluations) noexcept; + +// formatting utilities +namespace fmt { + +// adds thousands separator to numbers +ANKERL_NANOBENCH(IGNORE_PADDED_PUSH) +class NumSep : public std::numpunct<char> { +public: + explicit NumSep(char sep); + char do_thousands_sep() const override; + std::string do_grouping() const override; + +private: + char mSep; +}; +ANKERL_NANOBENCH(IGNORE_PADDED_POP) + +// RAII to save & restore a stream's state +ANKERL_NANOBENCH(IGNORE_PADDED_PUSH) +class StreamStateRestorer { +public: + explicit StreamStateRestorer(std::ostream& s); + ~StreamStateRestorer(); + + // sets back all stream info that we remembered at construction + void restore(); + + // don't allow copying / moving + StreamStateRestorer(StreamStateRestorer const&) = delete; + StreamStateRestorer& operator=(StreamStateRestorer const&) = delete; + StreamStateRestorer(StreamStateRestorer&&) = delete; + StreamStateRestorer& operator=(StreamStateRestorer&&) = delete; + +private: + std::ostream& mStream; + std::locale mLocale; + std::streamsize const mPrecision; + std::streamsize const mWidth; + std::ostream::char_type const mFill; + std::ostream::fmtflags const mFmtFlags; +}; +ANKERL_NANOBENCH(IGNORE_PADDED_POP) + +// Number formatter +class Number { +public: + Number(int width, int precision, double value); + Number(int width, int precision, int64_t value); + std::string to_s() const; + +private: + friend std::ostream& operator<<(std::ostream& os, Number const& n); + std::ostream& write(std::ostream& os) const; + + int mWidth; + int mPrecision; + double mValue; +}; + +// helper replacement for std::to_string of signed/unsigned numbers so we are locale independent +std::string to_s(uint64_t s); + +std::ostream& operator<<(std::ostream& os, Number const& n); + +class MarkDownColumn { +public: + MarkDownColumn(int w, int prec, std::string const& tit, std::string const& suff, double val); + std::string title() const; + std::string separator() const; + std::string invalid() const; + std::string value() const; + +private: + int mWidth; + int mPrecision; + std::string mTitle; + std::string mSuffix; + double mValue; +}; + +// Formats any text as markdown code, escaping backticks. +class MarkDownCode { +public: + explicit MarkDownCode(std::string const& what); + +private: + friend std::ostream& operator<<(std::ostream& os, MarkDownCode const& mdCode); + std::ostream& write(std::ostream& os) const; + + std::string mWhat{}; +}; + +std::ostream& operator<<(std::ostream& os, MarkDownCode const& mdCode); + +} // namespace fmt +} // namespace detail +} // namespace nanobench +} // namespace ankerl + +// implementation ///////////////////////////////////////////////////////////////////////////////// + +namespace ankerl { +namespace nanobench { + +void render(char const* mustacheTemplate, std::vector<Result> const& results, std::ostream& out) { + detail::fmt::StreamStateRestorer restorer(out); + + out.precision(std::numeric_limits<double>::digits10); + auto nodes = templates::parseMustacheTemplate(&mustacheTemplate); + + for (auto const& n : nodes) { + ANKERL_NANOBENCH_LOG("n.type=" << static_cast<int>(n.type)); + switch (n.type) { + case templates::Node::Type::content: + out.write(n.begin, std::distance(n.begin, n.end)); + break; + + case templates::Node::Type::inverted_section: + throw std::runtime_error("unknown list '" + std::string(n.begin, n.end) + "'"); + + case templates::Node::Type::section: + if (n == "result") { + const size_t nbResults = results.size(); + for (size_t i = 0; i < nbResults; ++i) { + generateResult(n.children, i, results, out); + } + } else { + throw std::runtime_error("unknown section '" + std::string(n.begin, n.end) + "'"); + } + break; + + case templates::Node::Type::tag: + // This just uses the last result's config. + if (!generateConfigTag(n, results.back().config(), out)) { + throw std::runtime_error("unknown tag '" + std::string(n.begin, n.end) + "'"); + } + break; + } + } +} + +void render(char const* mustacheTemplate, const Bench& bench, std::ostream& out) { + render(mustacheTemplate, bench.results(), out); +} + +namespace detail { + +PerformanceCounters& performanceCounters() { +# if defined(__clang__) +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wexit-time-destructors" +# endif + static PerformanceCounters pc; +# if defined(__clang__) +# pragma clang diagnostic pop +# endif + return pc; +} + +// Windows version of doNotOptimizeAway +// see https://github.com/google/benchmark/blob/master/include/benchmark/benchmark.h#L307 +// see https://github.com/facebook/folly/blob/master/folly/Benchmark.h#L280 +// see https://docs.microsoft.com/en-us/cpp/preprocessor/optimize +# if defined(_MSC_VER) +# pragma optimize("", off) +void doNotOptimizeAwaySink(void const*) {} +# pragma optimize("", on) +# endif + +template <typename T> +T parseFile(std::string const& filename) { + std::ifstream fin(filename); + T num{}; + fin >> num; + return num; +} + +char const* getEnv(char const* name) { +# if defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable : 4996) // getenv': This function or variable may be unsafe. +# endif + return std::getenv(name); +# if defined(_MSC_VER) +# pragma warning(pop) +# endif +} + +bool isEndlessRunning(std::string const& name) { + auto endless = getEnv("NANOBENCH_ENDLESS"); + return nullptr != endless && endless == name; +} + +void gatherStabilityInformation(std::vector<std::string>& warnings, std::vector<std::string>& recommendations) { + warnings.clear(); + recommendations.clear(); + + bool recommendCheckFlags = false; + +# if defined(DEBUG) + warnings.emplace_back("DEBUG defined"); + recommendCheckFlags = true; +# endif + + bool recommendPyPerf = false; +# if defined(__linux__) + auto nprocs = sysconf(_SC_NPROCESSORS_CONF); + if (nprocs <= 0) { + warnings.emplace_back("couldn't figure out number of processors - no governor, turbo check possible"); + } else { + + // check frequency scaling + for (long id = 0; id < nprocs; ++id) { + auto idStr = detail::fmt::to_s(static_cast<uint64_t>(id)); + auto sysCpu = "/sys/devices/system/cpu/cpu" + idStr; + auto minFreq = parseFile<int64_t>(sysCpu + "/cpufreq/scaling_min_freq"); + auto maxFreq = parseFile<int64_t>(sysCpu + "/cpufreq/scaling_max_freq"); + if (minFreq != maxFreq) { + auto minMHz = static_cast<double>(minFreq) / 1000.0; + auto maxMHz = static_cast<double>(maxFreq) / 1000.0; + warnings.emplace_back("CPU frequency scaling enabled: CPU " + idStr + " between " + + detail::fmt::Number(1, 1, minMHz).to_s() + " and " + detail::fmt::Number(1, 1, maxMHz).to_s() + + " MHz"); + recommendPyPerf = true; + break; + } + } + + auto currentGovernor = parseFile<std::string>("/sys/devices/system/cpu/cpu0/cpufreq/scaling_governor"); + if ("performance" != currentGovernor) { + warnings.emplace_back("CPU governor is '" + currentGovernor + "' but should be 'performance'"); + recommendPyPerf = true; + } + + if (0 == parseFile<int>("/sys/devices/system/cpu/intel_pstate/no_turbo")) { + warnings.emplace_back("Turbo is enabled, CPU frequency will fluctuate"); + recommendPyPerf = true; + } + } +# endif + + if (recommendCheckFlags) { + recommendations.emplace_back("Make sure you compile for Release"); + } + if (recommendPyPerf) { + recommendations.emplace_back("Use 'pyperf system tune' before benchmarking. See https://github.com/vstinner/pyperf"); + } +} + +void printStabilityInformationOnce(std::ostream* outStream) { + static bool shouldPrint = true; + if (shouldPrint && outStream) { + auto& os = *outStream; + shouldPrint = false; + std::vector<std::string> warnings; + std::vector<std::string> recommendations; + gatherStabilityInformation(warnings, recommendations); + if (warnings.empty()) { + return; + } + + os << "Warning, results might be unstable:" << std::endl; + for (auto const& w : warnings) { + os << "* " << w << std::endl; + } + + os << std::endl << "Recommendations" << std::endl; + for (auto const& r : recommendations) { + os << "* " << r << std::endl; + } + } +} + +// remembers the last table settings used. When it changes, a new table header is automatically written for the new entry. +uint64_t& singletonHeaderHash() noexcept { + static uint64_t sHeaderHash{}; + return sHeaderHash; +} + +ANKERL_NANOBENCH_NO_SANITIZE("integer") +inline uint64_t fnv1a(std::string const& str) noexcept { + auto val = UINT64_C(14695981039346656037); + for (auto c : str) { + val = (val ^ static_cast<uint8_t>(c)) * UINT64_C(1099511628211); + } + return val; +} + +ANKERL_NANOBENCH_NO_SANITIZE("integer") +inline uint64_t hash_combine(uint64_t seed, uint64_t val) { + return seed ^ (val + UINT64_C(0x9e3779b9) + (seed << 6U) + (seed >> 2U)); +} + +// determines resolution of the given clock. This is done by measuring multiple times and returning the minimum time difference. +Clock::duration calcClockResolution(size_t numEvaluations) noexcept { + auto bestDuration = Clock::duration::max(); + Clock::time_point tBegin; + Clock::time_point tEnd; + for (size_t i = 0; i < numEvaluations; ++i) { + tBegin = Clock::now(); + do { + tEnd = Clock::now(); + } while (tBegin == tEnd); + bestDuration = (std::min)(bestDuration, tEnd - tBegin); + } + return bestDuration; +} + +// Calculates clock resolution once, and remembers the result +Clock::duration clockResolution() noexcept { + static Clock::duration sResolution = calcClockResolution(20); + return sResolution; +} + +ANKERL_NANOBENCH(IGNORE_PADDED_PUSH) +struct IterationLogic::Impl { + enum class State { warmup, upscaling_runtime, measuring, endless }; + + explicit Impl(Bench const& bench) + : mBench(bench) + , mResult(bench.config()) { + printStabilityInformationOnce(mBench.output()); + + // determine target runtime per epoch + mTargetRuntimePerEpoch = detail::clockResolution() * mBench.clockResolutionMultiple(); + if (mTargetRuntimePerEpoch > mBench.maxEpochTime()) { + mTargetRuntimePerEpoch = mBench.maxEpochTime(); + } + if (mTargetRuntimePerEpoch < mBench.minEpochTime()) { + mTargetRuntimePerEpoch = mBench.minEpochTime(); + } + + if (isEndlessRunning(mBench.name())) { + std::cerr << "NANOBENCH_ENDLESS set: running '" << mBench.name() << "' endlessly" << std::endl; + mNumIters = (std::numeric_limits<uint64_t>::max)(); + mState = State::endless; + } else if (0 != mBench.warmup()) { + mNumIters = mBench.warmup(); + mState = State::warmup; + } else if (0 != mBench.epochIterations()) { + // exact number of iterations + mNumIters = mBench.epochIterations(); + mState = State::measuring; + } else { + mNumIters = mBench.minEpochIterations(); + mState = State::upscaling_runtime; + } + } + + // directly calculates new iters based on elapsed&iters, and adds a 10% noise. Makes sure we don't underflow. + ANKERL_NANOBENCH(NODISCARD) uint64_t calcBestNumIters(std::chrono::nanoseconds elapsed, uint64_t iters) noexcept { + auto doubleElapsed = d(elapsed); + auto doubleTargetRuntimePerEpoch = d(mTargetRuntimePerEpoch); + auto doubleNewIters = doubleTargetRuntimePerEpoch / doubleElapsed * d(iters); + + auto doubleMinEpochIters = d(mBench.minEpochIterations()); + if (doubleNewIters < doubleMinEpochIters) { + doubleNewIters = doubleMinEpochIters; + } + doubleNewIters *= 1.0 + 0.2 * mRng.uniform01(); + + // +0.5 for correct rounding when casting + // NOLINTNEXTLINE(bugprone-incorrect-roundings) + return static_cast<uint64_t>(doubleNewIters + 0.5); + } + + ANKERL_NANOBENCH_NO_SANITIZE("integer") void upscale(std::chrono::nanoseconds elapsed) { + if (elapsed * 10 < mTargetRuntimePerEpoch) { + // we are far below the target runtime. Multiply iterations by 10 (with overflow check) + if (mNumIters * 10 < mNumIters) { + // overflow :-( + showResult("iterations overflow. Maybe your code got optimized away?"); + mNumIters = 0; + return; + } + mNumIters *= 10; + } else { + mNumIters = calcBestNumIters(elapsed, mNumIters); + } + } + + void add(std::chrono::nanoseconds elapsed, PerformanceCounters const& pc) noexcept { +# if defined(ANKERL_NANOBENCH_LOG_ENABLED) + auto oldIters = mNumIters; +# endif + + switch (mState) { + case State::warmup: + if (isCloseEnoughForMeasurements(elapsed)) { + // if elapsed is close enough, we can skip upscaling and go right to measurements + // still, we don't add the result to the measurements. + mState = State::measuring; + mNumIters = calcBestNumIters(elapsed, mNumIters); + } else { + // not close enough: switch to upscaling + mState = State::upscaling_runtime; + upscale(elapsed); + } + break; + + case State::upscaling_runtime: + if (isCloseEnoughForMeasurements(elapsed)) { + // if we are close enough, add measurement and switch to always measuring + mState = State::measuring; + mTotalElapsed += elapsed; + mTotalNumIters += mNumIters; + mResult.add(elapsed, mNumIters, pc); + mNumIters = calcBestNumIters(mTotalElapsed, mTotalNumIters); + } else { + upscale(elapsed); + } + break; + + case State::measuring: + // just add measurements - no questions asked. Even when runtime is low. But we can't ignore + // that fluctuation, or else we would bias the result + mTotalElapsed += elapsed; + mTotalNumIters += mNumIters; + mResult.add(elapsed, mNumIters, pc); + if (0 != mBench.epochIterations()) { + mNumIters = mBench.epochIterations(); + } else { + mNumIters = calcBestNumIters(mTotalElapsed, mTotalNumIters); + } + break; + + case State::endless: + mNumIters = (std::numeric_limits<uint64_t>::max)(); + break; + } + + if (static_cast<uint64_t>(mResult.size()) == mBench.epochs()) { + // we got all the results that we need, finish it + showResult(""); + mNumIters = 0; + } + + ANKERL_NANOBENCH_LOG(mBench.name() << ": " << detail::fmt::Number(20, 3, static_cast<double>(elapsed.count())) << " elapsed, " + << detail::fmt::Number(20, 3, static_cast<double>(mTargetRuntimePerEpoch.count())) + << " target. oldIters=" << oldIters << ", mNumIters=" << mNumIters + << ", mState=" << static_cast<int>(mState)); + } + + void showResult(std::string const& errorMessage) const { + ANKERL_NANOBENCH_LOG(errorMessage); + + if (mBench.output() != nullptr) { + // prepare column data /////// + std::vector<fmt::MarkDownColumn> columns; + + auto rMedian = mResult.median(Result::Measure::elapsed); + + if (mBench.relative()) { + double d = 100.0; + if (!mBench.results().empty()) { + d = rMedian <= 0.0 ? 0.0 : mBench.results().front().median(Result::Measure::elapsed) / rMedian * 100.0; + } + columns.emplace_back(11, 1, "relative", "%", d); + } + + if (mBench.complexityN() > 0) { + columns.emplace_back(14, 0, "complexityN", "", mBench.complexityN()); + } + + columns.emplace_back(22, 2, "ns/" + mBench.unit(), "", 1e9 * rMedian / mBench.batch()); + columns.emplace_back(22, 2, mBench.unit() + "/s", "", rMedian <= 0.0 ? 0.0 : mBench.batch() / rMedian); + + double rErrorMedian = mResult.medianAbsolutePercentError(Result::Measure::elapsed); + columns.emplace_back(10, 1, "err%", "%", rErrorMedian * 100.0); + + double rInsMedian = -1.0; + if (mResult.has(Result::Measure::instructions)) { + rInsMedian = mResult.median(Result::Measure::instructions); + columns.emplace_back(18, 2, "ins/" + mBench.unit(), "", rInsMedian / mBench.batch()); + } + + double rCycMedian = -1.0; + if (mResult.has(Result::Measure::cpucycles)) { + rCycMedian = mResult.median(Result::Measure::cpucycles); + columns.emplace_back(18, 2, "cyc/" + mBench.unit(), "", rCycMedian / mBench.batch()); + } + if (rInsMedian > 0.0 && rCycMedian > 0.0) { + columns.emplace_back(9, 3, "IPC", "", rCycMedian <= 0.0 ? 0.0 : rInsMedian / rCycMedian); + } + if (mResult.has(Result::Measure::branchinstructions)) { + double rBraMedian = mResult.median(Result::Measure::branchinstructions); + columns.emplace_back(17, 2, "bra/" + mBench.unit(), "", rBraMedian / mBench.batch()); + if (mResult.has(Result::Measure::branchmisses)) { + double p = 0.0; + if (rBraMedian >= 1e-9) { + p = 100.0 * mResult.median(Result::Measure::branchmisses) / rBraMedian; + } + columns.emplace_back(10, 1, "miss%", "%", p); + } + } + + columns.emplace_back(12, 2, "total", "", mResult.sum(Result::Measure::elapsed)); + + // write everything + auto& os = *mBench.output(); + + uint64_t hash = 0; + hash = hash_combine(fnv1a(mBench.unit()), hash); + hash = hash_combine(fnv1a(mBench.title()), hash); + hash = hash_combine(mBench.relative(), hash); + hash = hash_combine(mBench.performanceCounters(), hash); + + if (hash != singletonHeaderHash()) { + singletonHeaderHash() = hash; + + // no result yet, print header + os << std::endl; + for (auto const& col : columns) { + os << col.title(); + } + os << "| " << mBench.title() << std::endl; + + for (auto const& col : columns) { + os << col.separator(); + } + os << "|:" << std::string(mBench.title().size() + 1U, '-') << std::endl; + } + + if (!errorMessage.empty()) { + for (auto const& col : columns) { + os << col.invalid(); + } + os << "| :boom: " << fmt::MarkDownCode(mBench.name()) << " (" << errorMessage << ')' << std::endl; + } else { + for (auto const& col : columns) { + os << col.value(); + } + os << "| "; + auto showUnstable = rErrorMedian >= 0.05; + if (showUnstable) { + os << ":wavy_dash: "; + } + os << fmt::MarkDownCode(mBench.name()); + if (showUnstable) { + auto avgIters = static_cast<double>(mTotalNumIters) / static_cast<double>(mBench.epochs()); + // NOLINTNEXTLINE(bugprone-incorrect-roundings) + auto suggestedIters = static_cast<uint64_t>(avgIters * 10 + 0.5); + + os << " (Unstable with ~" << detail::fmt::Number(1, 1, avgIters) + << " iters. Increase `minEpochIterations` to e.g. " << suggestedIters << ")"; + } + os << std::endl; + } + } + } + + ANKERL_NANOBENCH(NODISCARD) bool isCloseEnoughForMeasurements(std::chrono::nanoseconds elapsed) const noexcept { + return elapsed * 3 >= mTargetRuntimePerEpoch * 2; + } + + uint64_t mNumIters = 1; + Bench const& mBench; + std::chrono::nanoseconds mTargetRuntimePerEpoch{}; + Result mResult; + Rng mRng{123}; + std::chrono::nanoseconds mTotalElapsed{}; + uint64_t mTotalNumIters = 0; + + State mState = State::upscaling_runtime; +}; +ANKERL_NANOBENCH(IGNORE_PADDED_POP) + +IterationLogic::IterationLogic(Bench const& bench) noexcept + : mPimpl(new Impl(bench)) {} + +IterationLogic::~IterationLogic() { + if (mPimpl) { + delete mPimpl; + } +} + +uint64_t IterationLogic::numIters() const noexcept { + ANKERL_NANOBENCH_LOG(mPimpl->mBench.name() << ": mNumIters=" << mPimpl->mNumIters); + return mPimpl->mNumIters; +} + +void IterationLogic::add(std::chrono::nanoseconds elapsed, PerformanceCounters const& pc) noexcept { + mPimpl->add(elapsed, pc); +} + +void IterationLogic::moveResultTo(std::vector<Result>& results) noexcept { + results.emplace_back(std::move(mPimpl->mResult)); +} + +# if ANKERL_NANOBENCH(PERF_COUNTERS) + +ANKERL_NANOBENCH(IGNORE_PADDED_PUSH) +class LinuxPerformanceCounters { +public: + struct Target { + Target(uint64_t* targetValue_, bool correctMeasuringOverhead_, bool correctLoopOverhead_) + : targetValue(targetValue_) + , correctMeasuringOverhead(correctMeasuringOverhead_) + , correctLoopOverhead(correctLoopOverhead_) {} + + uint64_t* targetValue{}; + bool correctMeasuringOverhead{}; + bool correctLoopOverhead{}; + }; + + ~LinuxPerformanceCounters(); + + // quick operation + inline void start() {} + + inline void stop() {} + + bool monitor(perf_sw_ids swId, Target target); + bool monitor(perf_hw_id hwId, Target target); + + bool hasError() const noexcept { + return mHasError; + } + + // Just reading data is faster than enable & disabling. + // we subtract data ourselves. + inline void beginMeasure() { + if (mHasError) { + return; + } + + // NOLINTNEXTLINE(hicpp-signed-bitwise) + mHasError = -1 == ioctl(mFd, PERF_EVENT_IOC_RESET, PERF_IOC_FLAG_GROUP); + if (mHasError) { + return; + } + + // NOLINTNEXTLINE(hicpp-signed-bitwise) + mHasError = -1 == ioctl(mFd, PERF_EVENT_IOC_ENABLE, PERF_IOC_FLAG_GROUP); + } + + inline void endMeasure() { + if (mHasError) { + return; + } + + // NOLINTNEXTLINE(hicpp-signed-bitwise) + mHasError = (-1 == ioctl(mFd, PERF_EVENT_IOC_DISABLE, PERF_IOC_FLAG_GROUP)); + if (mHasError) { + return; + } + + auto const numBytes = sizeof(uint64_t) * mCounters.size(); + auto ret = read(mFd, mCounters.data(), numBytes); + mHasError = ret != static_cast<ssize_t>(numBytes); + } + + void updateResults(uint64_t numIters); + + // rounded integer division + template <typename T> + static inline T divRounded(T a, T divisor) { + return (a + divisor / 2) / divisor; + } + + template <typename Op> + ANKERL_NANOBENCH_NO_SANITIZE("integer") + void calibrate(Op&& op) { + // clear current calibration data, + for (auto& v : mCalibratedOverhead) { + v = UINT64_C(0); + } + + // create new calibration data + auto newCalibration = mCalibratedOverhead; + for (auto& v : newCalibration) { + v = (std::numeric_limits<uint64_t>::max)(); + } + for (size_t iter = 0; iter < 100; ++iter) { + beginMeasure(); + op(); + endMeasure(); + if (mHasError) { + return; + } + + for (size_t i = 0; i < newCalibration.size(); ++i) { + auto diff = mCounters[i]; + if (newCalibration[i] > diff) { + newCalibration[i] = diff; + } + } + } + + mCalibratedOverhead = std::move(newCalibration); + + { + // calibrate loop overhead. For branches & instructions this makes sense, not so much for everything else like cycles. + // marsaglia's xorshift: mov, sal/shr, xor. Times 3. + // This has the nice property that the compiler doesn't seem to be able to optimize multiple calls any further. + // see https://godbolt.org/z/49RVQ5 + uint64_t const numIters = 100000U + (std::random_device{}() & 3); + uint64_t n = numIters; + uint32_t x = 1234567; + auto fn = [&]() { + x ^= x << 13; + x ^= x >> 17; + x ^= x << 5; + }; + + beginMeasure(); + while (n-- > 0) { + fn(); + } + endMeasure(); + detail::doNotOptimizeAway(x); + auto measure1 = mCounters; + + n = numIters; + beginMeasure(); + while (n-- > 0) { + // we now run *twice* so we can easily calculate the overhead + fn(); + fn(); + } + endMeasure(); + detail::doNotOptimizeAway(x); + auto measure2 = mCounters; + + for (size_t i = 0; i < mCounters.size(); ++i) { + // factor 2 because we have two instructions per loop + auto m1 = measure1[i] > mCalibratedOverhead[i] ? measure1[i] - mCalibratedOverhead[i] : 0; + auto m2 = measure2[i] > mCalibratedOverhead[i] ? measure2[i] - mCalibratedOverhead[i] : 0; + auto overhead = m1 * 2 > m2 ? m1 * 2 - m2 : 0; + + mLoopOverhead[i] = divRounded(overhead, numIters); + } + } + } + +private: + bool monitor(uint32_t type, uint64_t eventid, Target target); + + std::map<uint64_t, Target> mIdToTarget{}; + + // start with minimum size of 3 for read_format + std::vector<uint64_t> mCounters{3}; + std::vector<uint64_t> mCalibratedOverhead{3}; + std::vector<uint64_t> mLoopOverhead{3}; + + uint64_t mTimeEnabledNanos = 0; + uint64_t mTimeRunningNanos = 0; + int mFd = -1; + bool mHasError = false; +}; +ANKERL_NANOBENCH(IGNORE_PADDED_POP) + +LinuxPerformanceCounters::~LinuxPerformanceCounters() { + if (-1 != mFd) { + close(mFd); + } +} + +bool LinuxPerformanceCounters::monitor(perf_sw_ids swId, LinuxPerformanceCounters::Target target) { + return monitor(PERF_TYPE_SOFTWARE, swId, target); +} + +bool LinuxPerformanceCounters::monitor(perf_hw_id hwId, LinuxPerformanceCounters::Target target) { + return monitor(PERF_TYPE_HARDWARE, hwId, target); +} + +// overflow is ok, it's checked +ANKERL_NANOBENCH_NO_SANITIZE("integer") +void LinuxPerformanceCounters::updateResults(uint64_t numIters) { + // clear old data + for (auto& id_value : mIdToTarget) { + *id_value.second.targetValue = UINT64_C(0); + } + + if (mHasError) { + return; + } + + mTimeEnabledNanos = mCounters[1] - mCalibratedOverhead[1]; + mTimeRunningNanos = mCounters[2] - mCalibratedOverhead[2]; + + for (uint64_t i = 0; i < mCounters[0]; ++i) { + auto idx = static_cast<size_t>(3 + i * 2 + 0); + auto id = mCounters[idx + 1U]; + + auto it = mIdToTarget.find(id); + if (it != mIdToTarget.end()) { + + auto& tgt = it->second; + *tgt.targetValue = mCounters[idx]; + if (tgt.correctMeasuringOverhead) { + if (*tgt.targetValue >= mCalibratedOverhead[idx]) { + *tgt.targetValue -= mCalibratedOverhead[idx]; + } else { + *tgt.targetValue = 0U; + } + } + if (tgt.correctLoopOverhead) { + auto correctionVal = mLoopOverhead[idx] * numIters; + if (*tgt.targetValue >= correctionVal) { + *tgt.targetValue -= correctionVal; + } else { + *tgt.targetValue = 0U; + } + } + } + } +} + +bool LinuxPerformanceCounters::monitor(uint32_t type, uint64_t eventid, Target target) { + *target.targetValue = (std::numeric_limits<uint64_t>::max)(); + if (mHasError) { + return false; + } + + auto pea = perf_event_attr(); + std::memset(&pea, 0, sizeof(perf_event_attr)); + pea.type = type; + pea.size = sizeof(perf_event_attr); + pea.config = eventid; + pea.disabled = 1; // start counter as disabled + pea.exclude_kernel = 1; + pea.exclude_hv = 1; + + // NOLINTNEXTLINE(hicpp-signed-bitwise) + pea.read_format = PERF_FORMAT_GROUP | PERF_FORMAT_ID | PERF_FORMAT_TOTAL_TIME_ENABLED | PERF_FORMAT_TOTAL_TIME_RUNNING; + + const int pid = 0; // the current process + const int cpu = -1; // all CPUs +# if defined(PERF_FLAG_FD_CLOEXEC) // since Linux 3.14 + const unsigned long flags = PERF_FLAG_FD_CLOEXEC; +# else + const unsigned long flags = 0; +# endif + + auto fd = static_cast<int>(syscall(__NR_perf_event_open, &pea, pid, cpu, mFd, flags)); + if (-1 == fd) { + return false; + } + if (-1 == mFd) { + // first call: set to fd, and use this from now on + mFd = fd; + } + uint64_t id = 0; + // NOLINTNEXTLINE(hicpp-signed-bitwise) + if (-1 == ioctl(fd, PERF_EVENT_IOC_ID, &id)) { + // couldn't get id + return false; + } + + // insert into map, rely on the fact that map's references are constant. + mIdToTarget.emplace(id, target); + + // prepare readformat with the correct size (after the insert) + auto size = 3 + 2 * mIdToTarget.size(); + mCounters.resize(size); + mCalibratedOverhead.resize(size); + mLoopOverhead.resize(size); + + return true; +} + +PerformanceCounters::PerformanceCounters() + : mPc(new LinuxPerformanceCounters()) + , mVal() + , mHas() { + + mHas.pageFaults = mPc->monitor(PERF_COUNT_SW_PAGE_FAULTS, LinuxPerformanceCounters::Target(&mVal.pageFaults, true, false)); + mHas.cpuCycles = mPc->monitor(PERF_COUNT_HW_REF_CPU_CYCLES, LinuxPerformanceCounters::Target(&mVal.cpuCycles, true, false)); + mHas.contextSwitches = + mPc->monitor(PERF_COUNT_SW_CONTEXT_SWITCHES, LinuxPerformanceCounters::Target(&mVal.contextSwitches, true, false)); + mHas.instructions = mPc->monitor(PERF_COUNT_HW_INSTRUCTIONS, LinuxPerformanceCounters::Target(&mVal.instructions, true, true)); + mHas.branchInstructions = + mPc->monitor(PERF_COUNT_HW_BRANCH_INSTRUCTIONS, LinuxPerformanceCounters::Target(&mVal.branchInstructions, true, false)); + mHas.branchMisses = mPc->monitor(PERF_COUNT_HW_BRANCH_MISSES, LinuxPerformanceCounters::Target(&mVal.branchMisses, true, false)); + // mHas.branchMisses = false; + + mPc->start(); + mPc->calibrate([] { + auto before = ankerl::nanobench::Clock::now(); + auto after = ankerl::nanobench::Clock::now(); + (void)before; + (void)after; + }); + + if (mPc->hasError()) { + // something failed, don't monitor anything. + mHas = PerfCountSet<bool>{}; + } +} + +PerformanceCounters::~PerformanceCounters() { + if (nullptr != mPc) { + delete mPc; + } +} + +void PerformanceCounters::beginMeasure() { + mPc->beginMeasure(); +} + +void PerformanceCounters::endMeasure() { + mPc->endMeasure(); +} + +void PerformanceCounters::updateResults(uint64_t numIters) { + mPc->updateResults(numIters); +} + +# else + +PerformanceCounters::PerformanceCounters() = default; +PerformanceCounters::~PerformanceCounters() = default; +void PerformanceCounters::beginMeasure() {} +void PerformanceCounters::endMeasure() {} +void PerformanceCounters::updateResults(uint64_t) {} + +# endif + +ANKERL_NANOBENCH(NODISCARD) PerfCountSet<uint64_t> const& PerformanceCounters::val() const noexcept { + return mVal; +} +ANKERL_NANOBENCH(NODISCARD) PerfCountSet<bool> const& PerformanceCounters::has() const noexcept { + return mHas; +} + +// formatting utilities +namespace fmt { + +// adds thousands separator to numbers +NumSep::NumSep(char sep) + : mSep(sep) {} + +char NumSep::do_thousands_sep() const { + return mSep; +} + +std::string NumSep::do_grouping() const { + return "\003"; +} + +// RAII to save & restore a stream's state +StreamStateRestorer::StreamStateRestorer(std::ostream& s) + : mStream(s) + , mLocale(s.getloc()) + , mPrecision(s.precision()) + , mWidth(s.width()) + , mFill(s.fill()) + , mFmtFlags(s.flags()) {} + +StreamStateRestorer::~StreamStateRestorer() { + restore(); +} + +// sets back all stream info that we remembered at construction +void StreamStateRestorer::restore() { + mStream.imbue(mLocale); + mStream.precision(mPrecision); + mStream.width(mWidth); + mStream.fill(mFill); + mStream.flags(mFmtFlags); +} + +Number::Number(int width, int precision, int64_t value) + : mWidth(width) + , mPrecision(precision) + , mValue(static_cast<double>(value)) {} + +Number::Number(int width, int precision, double value) + : mWidth(width) + , mPrecision(precision) + , mValue(value) {} + +std::ostream& Number::write(std::ostream& os) const { + StreamStateRestorer restorer(os); + os.imbue(std::locale(os.getloc(), new NumSep(','))); + os << std::setw(mWidth) << std::setprecision(mPrecision) << std::fixed << mValue; + return os; +} + +std::string Number::to_s() const { + std::stringstream ss; + write(ss); + return ss.str(); +} + +std::string to_s(uint64_t n) { + std::string str; + do { + str += static_cast<char>('0' + static_cast<char>(n % 10)); + n /= 10; + } while (n != 0); + std::reverse(str.begin(), str.end()); + return str; +} + +std::ostream& operator<<(std::ostream& os, Number const& n) { + return n.write(os); +} + +MarkDownColumn::MarkDownColumn(int w, int prec, std::string const& tit, std::string const& suff, double val) + : mWidth(w) + , mPrecision(prec) + , mTitle(tit) + , mSuffix(suff) + , mValue(val) {} + +std::string MarkDownColumn::title() const { + std::stringstream ss; + ss << '|' << std::setw(mWidth - 2) << std::right << mTitle << ' '; + return ss.str(); +} + +std::string MarkDownColumn::separator() const { + std::string sep(static_cast<size_t>(mWidth), '-'); + sep.front() = '|'; + sep.back() = ':'; + return sep; +} + +std::string MarkDownColumn::invalid() const { + std::string sep(static_cast<size_t>(mWidth), ' '); + sep.front() = '|'; + sep[sep.size() - 2] = '-'; + return sep; +} + +std::string MarkDownColumn::value() const { + std::stringstream ss; + auto width = mWidth - 2 - static_cast<int>(mSuffix.size()); + ss << '|' << Number(width, mPrecision, mValue) << mSuffix << ' '; + return ss.str(); +} + +// Formats any text as markdown code, escaping backticks. +MarkDownCode::MarkDownCode(std::string const& what) { + mWhat.reserve(what.size() + 2); + mWhat.push_back('`'); + for (char c : what) { + mWhat.push_back(c); + if ('`' == c) { + mWhat.push_back('`'); + } + } + mWhat.push_back('`'); +} + +std::ostream& MarkDownCode::write(std::ostream& os) const { + return os << mWhat; +} + +std::ostream& operator<<(std::ostream& os, MarkDownCode const& mdCode) { + return mdCode.write(os); +} +} // namespace fmt +} // namespace detail + +// provide implementation here so it's only generated once +Config::Config() = default; +Config::~Config() = default; +Config& Config::operator=(Config const&) = default; +Config& Config::operator=(Config&&) = default; +Config::Config(Config const&) = default; +Config::Config(Config&&) noexcept = default; + +// provide implementation here so it's only generated once +Result::~Result() = default; +Result& Result::operator=(Result const&) = default; +Result& Result::operator=(Result&&) = default; +Result::Result(Result const&) = default; +Result::Result(Result&&) noexcept = default; + +namespace detail { +template <typename T> +inline constexpr typename std::underlying_type<T>::type u(T val) noexcept { + return static_cast<typename std::underlying_type<T>::type>(val); +} +} // namespace detail + +// Result returned after a benchmark has finished. Can be used as a baseline for relative(). +Result::Result(Config const& benchmarkConfig) + : mConfig(benchmarkConfig) + , mNameToMeasurements{detail::u(Result::Measure::_size)} {} + +void Result::add(Clock::duration totalElapsed, uint64_t iters, detail::PerformanceCounters const& pc) { + using detail::d; + using detail::u; + + double dIters = d(iters); + mNameToMeasurements[u(Result::Measure::iterations)].push_back(dIters); + + mNameToMeasurements[u(Result::Measure::elapsed)].push_back(d(totalElapsed) / dIters); + if (pc.has().pageFaults) { + mNameToMeasurements[u(Result::Measure::pagefaults)].push_back(d(pc.val().pageFaults) / dIters); + } + if (pc.has().cpuCycles) { + mNameToMeasurements[u(Result::Measure::cpucycles)].push_back(d(pc.val().cpuCycles) / dIters); + } + if (pc.has().contextSwitches) { + mNameToMeasurements[u(Result::Measure::contextswitches)].push_back(d(pc.val().contextSwitches) / dIters); + } + if (pc.has().instructions) { + mNameToMeasurements[u(Result::Measure::instructions)].push_back(d(pc.val().instructions) / dIters); + } + if (pc.has().branchInstructions) { + double branchInstructions = 0.0; + // correcting branches: remove branch introduced by the while (...) loop for each iteration. + if (pc.val().branchInstructions > iters + 1U) { + branchInstructions = d(pc.val().branchInstructions - (iters + 1U)); + } + mNameToMeasurements[u(Result::Measure::branchinstructions)].push_back(branchInstructions / dIters); + + if (pc.has().branchMisses) { + // correcting branch misses + double branchMisses = d(pc.val().branchMisses); + if (branchMisses > branchInstructions) { + // can't have branch misses when there were branches... + branchMisses = branchInstructions; + } + + // assuming at least one missed branch for the loop + branchMisses -= 1.0; + if (branchMisses < 1.0) { + branchMisses = 1.0; + } + mNameToMeasurements[u(Result::Measure::branchmisses)].push_back(branchMisses / dIters); + } + } +} + +Config const& Result::config() const noexcept { + return mConfig; +} + +inline double calcMedian(std::vector<double>& data) { + if (data.empty()) { + return 0.0; + } + std::sort(data.begin(), data.end()); + + auto midIdx = data.size() / 2U; + if (1U == (data.size() & 1U)) { + return data[midIdx]; + } + return (data[midIdx - 1U] + data[midIdx]) / 2U; +} + +double Result::median(Measure m) const { + // create a copy so we can sort + auto data = mNameToMeasurements[detail::u(m)]; + return calcMedian(data); +} + +double Result::average(Measure m) const { + using detail::d; + auto const& data = mNameToMeasurements[detail::u(m)]; + if (data.empty()) { + return 0.0; + } + + // create a copy so we can sort + return sum(m) / d(data.size()); +} + +double Result::medianAbsolutePercentError(Measure m) const { + // create copy + auto data = mNameToMeasurements[detail::u(m)]; + + // calculates MdAPE which is the median of percentage error + // see https://www.spiderfinancial.com/support/documentation/numxl/reference-manual/forecasting-performance/mdape + auto med = calcMedian(data); + + // transform the data to absolute error + for (auto& x : data) { + x = (x - med) / x; + if (x < 0) { + x = -x; + } + } + return calcMedian(data); +} + +double Result::sum(Measure m) const noexcept { + auto const& data = mNameToMeasurements[detail::u(m)]; + return std::accumulate(data.begin(), data.end(), 0.0); +} + +double Result::sumProduct(Measure m1, Measure m2) const noexcept { + auto const& data1 = mNameToMeasurements[detail::u(m1)]; + auto const& data2 = mNameToMeasurements[detail::u(m2)]; + + if (data1.size() != data2.size()) { + return 0.0; + } + + double result = 0.0; + for (size_t i = 0, s = data1.size(); i != s; ++i) { + result += data1[i] * data2[i]; + } + return result; +} + +bool Result::has(Measure m) const noexcept { + return !mNameToMeasurements[detail::u(m)].empty(); +} + +double Result::get(size_t idx, Measure m) const { + auto const& data = mNameToMeasurements[detail::u(m)]; + return data.at(idx); +} + +bool Result::empty() const noexcept { + return 0U == size(); +} + +size_t Result::size() const noexcept { + auto const& data = mNameToMeasurements[detail::u(Measure::elapsed)]; + return data.size(); +} + +double Result::minimum(Measure m) const noexcept { + auto const& data = mNameToMeasurements[detail::u(m)]; + if (data.empty()) { + return 0.0; + } + + // here its save to assume that at least one element is there + return *std::min_element(data.begin(), data.end()); +} + +double Result::maximum(Measure m) const noexcept { + auto const& data = mNameToMeasurements[detail::u(m)]; + if (data.empty()) { + return 0.0; + } + + // here its save to assume that at least one element is there + return *std::max_element(data.begin(), data.end()); +} + +Result::Measure Result::fromString(std::string const& str) { + if (str == "elapsed") { + return Measure::elapsed; + } else if (str == "iterations") { + return Measure::iterations; + } else if (str == "pagefaults") { + return Measure::pagefaults; + } else if (str == "cpucycles") { + return Measure::cpucycles; + } else if (str == "contextswitches") { + return Measure::contextswitches; + } else if (str == "instructions") { + return Measure::instructions; + } else if (str == "branchinstructions") { + return Measure::branchinstructions; + } else if (str == "branchmisses") { + return Measure::branchmisses; + } else { + // not found, return _size + return Measure::_size; + } +} + +// Configuration of a microbenchmark. +Bench::Bench() { + mConfig.mOut = &std::cout; +} + +Bench::Bench(Bench&&) = default; +Bench& Bench::operator=(Bench&&) = default; +Bench::Bench(Bench const&) = default; +Bench& Bench::operator=(Bench const&) = default; +Bench::~Bench() noexcept = default; + +double Bench::batch() const noexcept { + return mConfig.mBatch; +} + +double Bench::complexityN() const noexcept { + return mConfig.mComplexityN; +} + +// Set a baseline to compare it to. 100% it is exactly as fast as the baseline, >100% means it is faster than the baseline, <100% +// means it is slower than the baseline. +Bench& Bench::relative(bool isRelativeEnabled) noexcept { + mConfig.mIsRelative = isRelativeEnabled; + return *this; +} +bool Bench::relative() const noexcept { + return mConfig.mIsRelative; +} + +Bench& Bench::performanceCounters(bool showPerformanceCounters) noexcept { + mConfig.mShowPerformanceCounters = showPerformanceCounters; + return *this; +} +bool Bench::performanceCounters() const noexcept { + return mConfig.mShowPerformanceCounters; +} + +// Operation unit. Defaults to "op", could be e.g. "byte" for string processing. +// If u differs from currently set unit, the stored results will be cleared. +// Use singular (byte, not bytes). +Bench& Bench::unit(char const* u) { + if (u != mConfig.mUnit) { + mResults.clear(); + } + mConfig.mUnit = u; + return *this; +} + +Bench& Bench::unit(std::string const& u) { + return unit(u.c_str()); +} + +std::string const& Bench::unit() const noexcept { + return mConfig.mUnit; +} + +// If benchmarkTitle differs from currently set title, the stored results will be cleared. +Bench& Bench::title(const char* benchmarkTitle) { + if (benchmarkTitle != mConfig.mBenchmarkTitle) { + mResults.clear(); + } + mConfig.mBenchmarkTitle = benchmarkTitle; + return *this; +} +Bench& Bench::title(std::string const& benchmarkTitle) { + if (benchmarkTitle != mConfig.mBenchmarkTitle) { + mResults.clear(); + } + mConfig.mBenchmarkTitle = benchmarkTitle; + return *this; +} + +std::string const& Bench::title() const noexcept { + return mConfig.mBenchmarkTitle; +} + +Bench& Bench::name(const char* benchmarkName) { + mConfig.mBenchmarkName = benchmarkName; + return *this; +} + +Bench& Bench::name(std::string const& benchmarkName) { + mConfig.mBenchmarkName = benchmarkName; + return *this; +} + +std::string const& Bench::name() const noexcept { + return mConfig.mBenchmarkName; +} + +// Number of epochs to evaluate. The reported result will be the median of evaluation of each epoch. +Bench& Bench::epochs(size_t numEpochs) noexcept { + mConfig.mNumEpochs = numEpochs; + return *this; +} +size_t Bench::epochs() const noexcept { + return mConfig.mNumEpochs; +} + +// Desired evaluation time is a multiple of clock resolution. Default is to be 1000 times above this measurement precision. +Bench& Bench::clockResolutionMultiple(size_t multiple) noexcept { + mConfig.mClockResolutionMultiple = multiple; + return *this; +} +size_t Bench::clockResolutionMultiple() const noexcept { + return mConfig.mClockResolutionMultiple; +} + +// Sets the maximum time each epoch should take. Default is 100ms. +Bench& Bench::maxEpochTime(std::chrono::nanoseconds t) noexcept { + mConfig.mMaxEpochTime = t; + return *this; +} +std::chrono::nanoseconds Bench::maxEpochTime() const noexcept { + return mConfig.mMaxEpochTime; +} + +// Sets the maximum time each epoch should take. Default is 100ms. +Bench& Bench::minEpochTime(std::chrono::nanoseconds t) noexcept { + mConfig.mMinEpochTime = t; + return *this; +} +std::chrono::nanoseconds Bench::minEpochTime() const noexcept { + return mConfig.mMinEpochTime; +} + +Bench& Bench::minEpochIterations(uint64_t numIters) noexcept { + mConfig.mMinEpochIterations = (numIters == 0) ? 1 : numIters; + return *this; +} +uint64_t Bench::minEpochIterations() const noexcept { + return mConfig.mMinEpochIterations; +} + +Bench& Bench::epochIterations(uint64_t numIters) noexcept { + mConfig.mEpochIterations = numIters; + return *this; +} +uint64_t Bench::epochIterations() const noexcept { + return mConfig.mEpochIterations; +} + +Bench& Bench::warmup(uint64_t numWarmupIters) noexcept { + mConfig.mWarmup = numWarmupIters; + return *this; +} +uint64_t Bench::warmup() const noexcept { + return mConfig.mWarmup; +} + +Bench& Bench::config(Config const& benchmarkConfig) { + mConfig = benchmarkConfig; + return *this; +} +Config const& Bench::config() const noexcept { + return mConfig; +} + +Bench& Bench::output(std::ostream* outstream) noexcept { + mConfig.mOut = outstream; + return *this; +} + +ANKERL_NANOBENCH(NODISCARD) std::ostream* Bench::output() const noexcept { + return mConfig.mOut; +} + +std::vector<Result> const& Bench::results() const noexcept { + return mResults; +} + +Bench& Bench::render(char const* templateContent, std::ostream& os) { + ::ankerl::nanobench::render(templateContent, *this, os); + return *this; +} + +std::vector<BigO> Bench::complexityBigO() const { + std::vector<BigO> bigOs; + auto rangeMeasure = BigO::collectRangeMeasure(mResults); + bigOs.emplace_back("O(1)", rangeMeasure, [](double) { + return 1.0; + }); + bigOs.emplace_back("O(n)", rangeMeasure, [](double n) { + return n; + }); + bigOs.emplace_back("O(log n)", rangeMeasure, [](double n) { + return std::log2(n); + }); + bigOs.emplace_back("O(n log n)", rangeMeasure, [](double n) { + return n * std::log2(n); + }); + bigOs.emplace_back("O(n^2)", rangeMeasure, [](double n) { + return n * n; + }); + bigOs.emplace_back("O(n^3)", rangeMeasure, [](double n) { + return n * n * n; + }); + std::sort(bigOs.begin(), bigOs.end()); + return bigOs; +} + +Rng::Rng() + : mX(0) + , mY(0) { + std::random_device rd; + std::uniform_int_distribution<uint64_t> dist; + do { + mX = dist(rd); + mY = dist(rd); + } while (mX == 0 && mY == 0); +} + +ANKERL_NANOBENCH_NO_SANITIZE("integer") +uint64_t splitMix64(uint64_t& state) noexcept { + uint64_t z = (state += UINT64_C(0x9e3779b97f4a7c15)); + z = (z ^ (z >> 30U)) * UINT64_C(0xbf58476d1ce4e5b9); + z = (z ^ (z >> 27U)) * UINT64_C(0x94d049bb133111eb); + return z ^ (z >> 31U); +} + +// Seeded as described in romu paper (update april 2020) +Rng::Rng(uint64_t seed) noexcept + : mX(splitMix64(seed)) + , mY(splitMix64(seed)) { + for (size_t i = 0; i < 10; ++i) { + operator()(); + } +} + +// only internally used to copy the RNG. +Rng::Rng(uint64_t x, uint64_t y) noexcept + : mX(x) + , mY(y) {} + +Rng Rng::copy() const noexcept { + return Rng{mX, mY}; +} + +BigO::RangeMeasure BigO::collectRangeMeasure(std::vector<Result> const& results) { + BigO::RangeMeasure rangeMeasure; + for (auto const& result : results) { + if (result.config().mComplexityN > 0.0) { + rangeMeasure.emplace_back(result.config().mComplexityN, result.median(Result::Measure::elapsed)); + } + } + return rangeMeasure; +} + +BigO::BigO(std::string const& bigOName, RangeMeasure const& rangeMeasure) + : mName(bigOName) { + + // estimate the constant factor + double sumRangeMeasure = 0.0; + double sumRangeRange = 0.0; + + for (size_t i = 0; i < rangeMeasure.size(); ++i) { + sumRangeMeasure += rangeMeasure[i].first * rangeMeasure[i].second; + sumRangeRange += rangeMeasure[i].first * rangeMeasure[i].first; + } + mConstant = sumRangeMeasure / sumRangeRange; + + // calculate root mean square + double err = 0.0; + double sumMeasure = 0.0; + for (size_t i = 0; i < rangeMeasure.size(); ++i) { + auto diff = mConstant * rangeMeasure[i].first - rangeMeasure[i].second; + err += diff * diff; + + sumMeasure += rangeMeasure[i].second; + } + + auto n = static_cast<double>(rangeMeasure.size()); + auto mean = sumMeasure / n; + mNormalizedRootMeanSquare = std::sqrt(err / n) / mean; +} + +BigO::BigO(const char* bigOName, RangeMeasure const& rangeMeasure) + : BigO(std::string(bigOName), rangeMeasure) {} + +std::string const& BigO::name() const noexcept { + return mName; +} + +double BigO::constant() const noexcept { + return mConstant; +} + +double BigO::normalizedRootMeanSquare() const noexcept { + return mNormalizedRootMeanSquare; +} + +bool BigO::operator<(BigO const& other) const noexcept { + return std::tie(mNormalizedRootMeanSquare, mName) < std::tie(other.mNormalizedRootMeanSquare, other.mName); +} + +std::ostream& operator<<(std::ostream& os, BigO const& bigO) { + return os << bigO.constant() << " * " << bigO.name() << ", rms=" << bigO.normalizedRootMeanSquare(); +} + +std::ostream& operator<<(std::ostream& os, std::vector<ankerl::nanobench::BigO> const& bigOs) { + detail::fmt::StreamStateRestorer restorer(os); + os << std::endl << "| coefficient | err% | complexity" << std::endl << "|--------------:|-------:|------------" << std::endl; + for (auto const& bigO : bigOs) { + os << "|" << std::setw(14) << std::setprecision(7) << std::scientific << bigO.constant() << " "; + os << "|" << detail::fmt::Number(6, 1, bigO.normalizedRootMeanSquare() * 100.0) << "% "; + os << "| " << bigO.name(); + os << std::endl; + } + return os; +} + +} // namespace nanobench +} // namespace ankerl + +#endif // ANKERL_NANOBENCH_IMPLEMENT +#endif // ANKERL_NANOBENCH_H_INCLUDED diff --git a/src/bench/poly1305.cpp b/src/bench/poly1305.cpp index 02e5fecc0d..d8db99e7d4 100644 --- a/src/bench/poly1305.cpp +++ b/src/bench/poly1305.cpp @@ -11,30 +11,31 @@ static constexpr uint64_t BUFFER_SIZE_TINY = 64; static constexpr uint64_t BUFFER_SIZE_SMALL = 256; static constexpr uint64_t BUFFER_SIZE_LARGE = 1024*1024; -static void POLY1305(benchmark::State& state, size_t buffersize) +static void POLY1305(benchmark::Bench& bench, size_t buffersize) { std::vector<unsigned char> tag(POLY1305_TAGLEN, 0); std::vector<unsigned char> key(POLY1305_KEYLEN, 0); std::vector<unsigned char> in(buffersize, 0); - while (state.KeepRunning()) + bench.batch(in.size()).unit("byte").run([&] { poly1305_auth(tag.data(), in.data(), in.size(), key.data()); + }); } -static void POLY1305_64BYTES(benchmark::State& state) +static void POLY1305_64BYTES(benchmark::Bench& bench) { - POLY1305(state, BUFFER_SIZE_TINY); + POLY1305(bench, BUFFER_SIZE_TINY); } -static void POLY1305_256BYTES(benchmark::State& state) +static void POLY1305_256BYTES(benchmark::Bench& bench) { - POLY1305(state, BUFFER_SIZE_SMALL); + POLY1305(bench, BUFFER_SIZE_SMALL); } -static void POLY1305_1MB(benchmark::State& state) +static void POLY1305_1MB(benchmark::Bench& bench) { - POLY1305(state, BUFFER_SIZE_LARGE); + POLY1305(bench, BUFFER_SIZE_LARGE); } -BENCHMARK(POLY1305_64BYTES, 500000); -BENCHMARK(POLY1305_256BYTES, 250000); -BENCHMARK(POLY1305_1MB, 340); +BENCHMARK(POLY1305_64BYTES); +BENCHMARK(POLY1305_256BYTES); +BENCHMARK(POLY1305_1MB); diff --git a/src/bench/prevector.cpp b/src/bench/prevector.cpp index 42b351a72d..a2dbefa54a 100644 --- a/src/bench/prevector.cpp +++ b/src/bench/prevector.cpp @@ -30,51 +30,44 @@ static_assert(IS_TRIVIALLY_CONSTRUCTIBLE<trivial_t>::value, "expected trivial_t to be trivially constructible"); template <typename T> -static void PrevectorDestructor(benchmark::State& state) +static void PrevectorDestructor(benchmark::Bench& bench) { - while (state.KeepRunning()) { - for (auto x = 0; x < 1000; ++x) { - prevector<28, T> t0; - prevector<28, T> t1; - t0.resize(28); - t1.resize(29); - } - } + bench.batch(2).run([&] { + prevector<28, T> t0; + prevector<28, T> t1; + t0.resize(28); + t1.resize(29); + }); } template <typename T> -static void PrevectorClear(benchmark::State& state) +static void PrevectorClear(benchmark::Bench& bench) { - - while (state.KeepRunning()) { - for (auto x = 0; x < 1000; ++x) { - prevector<28, T> t0; - prevector<28, T> t1; - t0.resize(28); - t0.clear(); - t1.resize(29); - t1.clear(); - } - } + prevector<28, T> t0; + prevector<28, T> t1; + bench.batch(2).run([&] { + t0.resize(28); + t0.clear(); + t1.resize(29); + t1.clear(); + }); } template <typename T> -static void PrevectorResize(benchmark::State& state) +static void PrevectorResize(benchmark::Bench& bench) { - while (state.KeepRunning()) { - prevector<28, T> t0; - prevector<28, T> t1; - for (auto x = 0; x < 1000; ++x) { - t0.resize(28); - t0.resize(0); - t1.resize(29); - t1.resize(0); - } - } + prevector<28, T> t0; + prevector<28, T> t1; + bench.batch(4).run([&] { + t0.resize(28); + t0.resize(0); + t1.resize(29); + t1.resize(0); + }); } template <typename T> -static void PrevectorDeserialize(benchmark::State& state) +static void PrevectorDeserialize(benchmark::Bench& bench) { CDataStream s0(SER_NETWORK, 0); prevector<28, T> t0; @@ -86,26 +79,28 @@ static void PrevectorDeserialize(benchmark::State& state) for (auto x = 0; x < 101; ++x) { s0 << t0; } - while (state.KeepRunning()) { + bench.batch(1000).run([&] { prevector<28, T> t1; for (auto x = 0; x < 1000; ++x) { s0 >> t1; } s0.Init(SER_NETWORK, 0); - } + }); } -#define PREVECTOR_TEST(name, nontrivops, trivops) \ - static void Prevector ## name ## Nontrivial(benchmark::State& state) { \ - Prevector ## name<nontrivial_t>(state); \ - } \ - BENCHMARK(Prevector ## name ## Nontrivial, nontrivops); \ - static void Prevector ## name ## Trivial(benchmark::State& state) { \ - Prevector ## name<trivial_t>(state); \ - } \ - BENCHMARK(Prevector ## name ## Trivial, trivops); +#define PREVECTOR_TEST(name) \ + static void Prevector##name##Nontrivial(benchmark::Bench& bench) \ + { \ + Prevector##name<nontrivial_t>(bench); \ + } \ + BENCHMARK(Prevector##name##Nontrivial); \ + static void Prevector##name##Trivial(benchmark::Bench& bench) \ + { \ + Prevector##name<trivial_t>(bench); \ + } \ + BENCHMARK(Prevector##name##Trivial); -PREVECTOR_TEST(Clear, 28300, 88600) -PREVECTOR_TEST(Destructor, 28800, 88900) -PREVECTOR_TEST(Resize, 28900, 90300) -PREVECTOR_TEST(Deserialize, 6800, 52000) +PREVECTOR_TEST(Clear) +PREVECTOR_TEST(Destructor) +PREVECTOR_TEST(Resize) +PREVECTOR_TEST(Deserialize) diff --git a/src/bench/rollingbloom.cpp b/src/bench/rollingbloom.cpp index 6cdb4ff0a7..9b43951e6e 100644 --- a/src/bench/rollingbloom.cpp +++ b/src/bench/rollingbloom.cpp @@ -6,12 +6,12 @@ #include <bench/bench.h> #include <bloom.h> -static void RollingBloom(benchmark::State& state) +static void RollingBloom(benchmark::Bench& bench) { CRollingBloomFilter filter(120000, 0.000001); std::vector<unsigned char> data(32); uint32_t count = 0; - while (state.KeepRunning()) { + bench.run([&] { count++; data[0] = count; data[1] = count >> 8; @@ -24,16 +24,16 @@ static void RollingBloom(benchmark::State& state) data[2] = count >> 8; data[3] = count; filter.contains(data); - } + }); } -static void RollingBloomReset(benchmark::State& state) +static void RollingBloomReset(benchmark::Bench& bench) { CRollingBloomFilter filter(120000, 0.000001); - while (state.KeepRunning()) { + bench.run([&] { filter.reset(); - } + }); } -BENCHMARK(RollingBloom, 1500 * 1000); -BENCHMARK(RollingBloomReset, 20000); +BENCHMARK(RollingBloom); +BENCHMARK(RollingBloomReset); diff --git a/src/bench/rpc_blockchain.cpp b/src/bench/rpc_blockchain.cpp index 511573abac..4b45264a3c 100644 --- a/src/bench/rpc_blockchain.cpp +++ b/src/bench/rpc_blockchain.cpp @@ -11,7 +11,8 @@ #include <univalue.h> -static void BlockToJsonVerbose(benchmark::State& state) { +static void BlockToJsonVerbose(benchmark::Bench& bench) +{ CDataStream stream(benchmark::data::block413567, SER_NETWORK, PROTOCOL_VERSION); char a = '\0'; stream.write(&a, 1); // Prevent compaction @@ -24,9 +25,9 @@ static void BlockToJsonVerbose(benchmark::State& state) { blockindex.phashBlock = &blockHash; blockindex.nBits = 403014710; - while (state.KeepRunning()) { + bench.run([&] { (void)blockToJSON(block, &blockindex, &blockindex, /*verbose*/ true); - } + }); } -BENCHMARK(BlockToJsonVerbose, 10); +BENCHMARK(BlockToJsonVerbose); diff --git a/src/bench/rpc_mempool.cpp b/src/bench/rpc_mempool.cpp index bf63cccf09..1ff41765cf 100644 --- a/src/bench/rpc_mempool.cpp +++ b/src/bench/rpc_mempool.cpp @@ -15,7 +15,7 @@ static void AddTx(const CTransactionRef& tx, const CAmount& fee, CTxMemPool& poo pool.addUnchecked(CTxMemPoolEntry(tx, fee, /* time */ 0, /* height */ 1, /* spendsCoinbase */ false, /* sigOpCost */ 4, lp)); } -static void RpcMempool(benchmark::State& state) +static void RpcMempool(benchmark::Bench& bench) { CTxMemPool pool; LOCK2(cs_main, pool.cs); @@ -32,9 +32,9 @@ static void RpcMempool(benchmark::State& state) AddTx(tx_r, /* fee */ i, pool); } - while (state.KeepRunning()) { + bench.run([&] { (void)MempoolToJSON(pool, /*verbose*/ true); - } + }); } -BENCHMARK(RpcMempool, 40); +BENCHMARK(RpcMempool); diff --git a/src/bench/util_time.cpp b/src/bench/util_time.cpp index 72d97354aa..fad179eb87 100644 --- a/src/bench/util_time.cpp +++ b/src/bench/util_time.cpp @@ -6,37 +6,37 @@ #include <util/time.h> -static void BenchTimeDeprecated(benchmark::State& state) +static void BenchTimeDeprecated(benchmark::Bench& bench) { - while (state.KeepRunning()) { + bench.run([&] { (void)GetTime(); - } + }); } -static void BenchTimeMock(benchmark::State& state) +static void BenchTimeMock(benchmark::Bench& bench) { SetMockTime(111); - while (state.KeepRunning()) { + bench.run([&] { (void)GetTime<std::chrono::seconds>(); - } + }); SetMockTime(0); } -static void BenchTimeMillis(benchmark::State& state) +static void BenchTimeMillis(benchmark::Bench& bench) { - while (state.KeepRunning()) { + bench.run([&] { (void)GetTime<std::chrono::milliseconds>(); - } + }); } -static void BenchTimeMillisSys(benchmark::State& state) +static void BenchTimeMillisSys(benchmark::Bench& bench) { - while (state.KeepRunning()) { + bench.run([&] { (void)GetTimeMillis(); - } + }); } -BENCHMARK(BenchTimeDeprecated, 100000000); -BENCHMARK(BenchTimeMillis, 6000000); -BENCHMARK(BenchTimeMillisSys, 6000000); -BENCHMARK(BenchTimeMock, 300000000); +BENCHMARK(BenchTimeDeprecated); +BENCHMARK(BenchTimeMillis); +BENCHMARK(BenchTimeMillisSys); +BENCHMARK(BenchTimeMock); diff --git a/src/bench/verify_script.cpp b/src/bench/verify_script.cpp index 14bca5f7d1..9af0b502eb 100644 --- a/src/bench/verify_script.cpp +++ b/src/bench/verify_script.cpp @@ -16,7 +16,7 @@ // Microbenchmark for verification of a basic P2WPKH script. Can be easily // modified to measure performance of other types of scripts. -static void VerifyScriptBench(benchmark::State& state) +static void VerifyScriptBench(benchmark::Bench& bench) { const ECCVerifyHandle verify_handle; ECC_Start(); @@ -34,7 +34,7 @@ static void VerifyScriptBench(benchmark::State& state) key.Set(vchKey.begin(), vchKey.end(), false); CPubKey pubkey = key.GetPubKey(); uint160 pubkeyHash; - CHash160().Write(pubkey.begin(), pubkey.size()).Finalize(pubkeyHash.begin()); + CHash160().Write(pubkey).Finalize(pubkeyHash); // Script. CScript scriptPubKey = CScript() << witnessversion << ToByteVector(pubkeyHash); @@ -49,7 +49,7 @@ static void VerifyScriptBench(benchmark::State& state) witness.stack.push_back(ToByteVector(pubkey)); // Benchmark. - while (state.KeepRunning()) { + bench.run([&] { ScriptError err; bool success = VerifyScript( txSpend.vin[0].scriptSig, @@ -71,11 +71,12 @@ static void VerifyScriptBench(benchmark::State& state) (const unsigned char*)stream.data(), stream.size(), 0, flags, nullptr); assert(csuccess == 1); #endif - } + }); ECC_Stop(); } -static void VerifyNestedIfScript(benchmark::State& state) { +static void VerifyNestedIfScript(benchmark::Bench& bench) +{ std::vector<std::vector<unsigned char>> stack; CScript script; for (int i = 0; i < 100; ++i) { @@ -87,15 +88,13 @@ static void VerifyNestedIfScript(benchmark::State& state) { for (int i = 0; i < 100; ++i) { script << OP_ENDIF; } - while (state.KeepRunning()) { + bench.run([&] { auto stack_copy = stack; ScriptError error; bool ret = EvalScript(stack_copy, script, 0, BaseSignatureChecker(), SigVersion::BASE, &error); assert(ret); - } + }); } - -BENCHMARK(VerifyScriptBench, 6300); - -BENCHMARK(VerifyNestedIfScript, 100); +BENCHMARK(VerifyScriptBench); +BENCHMARK(VerifyNestedIfScript); diff --git a/src/bench/wallet_balance.cpp b/src/bench/wallet_balance.cpp index 810c344ab5..e16182b48e 100644 --- a/src/bench/wallet_balance.cpp +++ b/src/bench/wallet_balance.cpp @@ -12,7 +12,7 @@ #include <validationinterface.h> #include <wallet/wallet.h> -static void WalletBalance(benchmark::State& state, const bool set_dirty, const bool add_watchonly, const bool add_mine) +static void WalletBalance(benchmark::Bench& bench, const bool set_dirty, const bool add_watchonly, const bool add_mine) { TestingSetup test_setup{ CBaseChainParams::REGTEST, @@ -26,7 +26,7 @@ static void WalletBalance(benchmark::State& state, const bool set_dirty, const b NodeContext node; std::unique_ptr<interfaces::Chain> chain = interfaces::MakeChain(node); - CWallet wallet{chain.get(), WalletLocation(), WalletDatabase::CreateMock()}; + CWallet wallet{chain.get(), WalletLocation(), CreateMockWalletDatabase()}; { wallet.SetupLegacyScriptPubKeyMan(); bool first_run; @@ -45,20 +45,20 @@ static void WalletBalance(benchmark::State& state, const bool set_dirty, const b auto bal = wallet.GetBalance(); // Cache - while (state.KeepRunning()) { + bench.run([&] { if (set_dirty) wallet.MarkDirty(); bal = wallet.GetBalance(); if (add_mine) assert(bal.m_mine_trusted > 0); if (add_watchonly) assert(bal.m_watchonly_trusted > 0); - } + }); } -static void WalletBalanceDirty(benchmark::State& state) { WalletBalance(state, /* set_dirty */ true, /* add_watchonly */ true, /* add_mine */ true); } -static void WalletBalanceClean(benchmark::State& state) { WalletBalance(state, /* set_dirty */ false, /* add_watchonly */ true, /* add_mine */ true); } -static void WalletBalanceMine(benchmark::State& state) { WalletBalance(state, /* set_dirty */ false, /* add_watchonly */ false, /* add_mine */ true); } -static void WalletBalanceWatch(benchmark::State& state) { WalletBalance(state, /* set_dirty */ false, /* add_watchonly */ true, /* add_mine */ false); } +static void WalletBalanceDirty(benchmark::Bench& bench) { WalletBalance(bench, /* set_dirty */ true, /* add_watchonly */ true, /* add_mine */ true); } +static void WalletBalanceClean(benchmark::Bench& bench) { WalletBalance(bench, /* set_dirty */ false, /* add_watchonly */ true, /* add_mine */ true); } +static void WalletBalanceMine(benchmark::Bench& bench) { WalletBalance(bench, /* set_dirty */ false, /* add_watchonly */ false, /* add_mine */ true); } +static void WalletBalanceWatch(benchmark::Bench& bench) { WalletBalance(bench, /* set_dirty */ false, /* add_watchonly */ true, /* add_mine */ false); } -BENCHMARK(WalletBalanceDirty, 2500); -BENCHMARK(WalletBalanceClean, 8000); -BENCHMARK(WalletBalanceMine, 16000); -BENCHMARK(WalletBalanceWatch, 8000); +BENCHMARK(WalletBalanceDirty); +BENCHMARK(WalletBalanceClean); +BENCHMARK(WalletBalanceMine); +BENCHMARK(WalletBalanceWatch); diff --git a/src/bitcoin-cli.cpp b/src/bitcoin-cli.cpp index 8d85789b4e..cf52b710cb 100644 --- a/src/bitcoin-cli.cpp +++ b/src/bitcoin-cli.cpp @@ -11,6 +11,7 @@ #include <clientversion.h> #include <optional.h> #include <rpc/client.h> +#include <rpc/mining.h> #include <rpc/protocol.h> #include <rpc/request.h> #include <util/strencodings.h> @@ -39,31 +40,35 @@ static const int DEFAULT_HTTP_CLIENT_TIMEOUT=900; static const bool DEFAULT_NAMED=false; static const int CONTINUE_EXECUTION=-1; -static void SetupCliArgs() +/** Default number of blocks to generate for RPC generatetoaddress. */ +static const std::string DEFAULT_NBLOCKS = "1"; + +static void SetupCliArgs(ArgsManager& argsman) { - SetupHelpOptions(gArgs); + SetupHelpOptions(argsman); const auto defaultBaseParams = CreateBaseChainParams(CBaseChainParams::MAIN); const auto testnetBaseParams = CreateBaseChainParams(CBaseChainParams::TESTNET); const auto regtestBaseParams = CreateBaseChainParams(CBaseChainParams::REGTEST); - gArgs.AddArg("-version", "Print version and exit", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-conf=<file>", strprintf("Specify configuration file. Relative paths will be prefixed by datadir location. (default: %s)", BITCOIN_CONF_FILENAME), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-datadir=<dir>", "Specify data directory", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-getinfo", "Get general information from the remote server. Note that unlike server-side RPC calls, the results of -getinfo is the result of multiple non-atomic requests. Some entries in the result may represent results from different states (e.g. wallet balance may be as of a different block from the chain state reported)", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - SetupChainParamsBaseOptions(); - gArgs.AddArg("-named", strprintf("Pass named instead of positional arguments (default: %s)", DEFAULT_NAMED), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-rpcclienttimeout=<n>", strprintf("Timeout in seconds during HTTP requests, or 0 for no timeout. (default: %d)", DEFAULT_HTTP_CLIENT_TIMEOUT), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-rpcconnect=<ip>", strprintf("Send commands to node running on <ip> (default: %s)", DEFAULT_RPCCONNECT), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-rpccookiefile=<loc>", "Location of the auth cookie. Relative paths will be prefixed by a net-specific datadir location. (default: data dir)", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-rpcpassword=<pw>", "Password for JSON-RPC connections", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-rpcport=<port>", strprintf("Connect to JSON-RPC on <port> (default: %u, testnet: %u, regtest: %u)", defaultBaseParams->RPCPort(), testnetBaseParams->RPCPort(), regtestBaseParams->RPCPort()), ArgsManager::ALLOW_ANY | ArgsManager::NETWORK_ONLY, OptionsCategory::OPTIONS); - gArgs.AddArg("-rpcuser=<user>", "Username for JSON-RPC connections", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-rpcwait", "Wait for RPC server to start", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-rpcwallet=<walletname>", "Send RPC for non-default wallet on RPC server (needs to exactly match corresponding -wallet option passed to bitcoind). This changes the RPC endpoint used, e.g. http://127.0.0.1:8332/wallet/<walletname>", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-stdin", "Read extra arguments from standard input, one per line until EOF/Ctrl-D (recommended for sensitive information such as passphrases). When combined with -stdinrpcpass, the first line from standard input is used for the RPC password.", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-stdinrpcpass", "Read RPC password from standard input as a single line. When combined with -stdin, the first line from standard input is used for the RPC password. When combined with -stdinwalletpassphrase, -stdinrpcpass consumes the first line, and -stdinwalletpassphrase consumes the second.", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-stdinwalletpassphrase", "Read wallet passphrase from standard input as a single line. When combined with -stdin, the first line from standard input is used for the wallet passphrase.", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-version", "Print version and exit", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-conf=<file>", strprintf("Specify configuration file. Relative paths will be prefixed by datadir location. (default: %s)", BITCOIN_CONF_FILENAME), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-datadir=<dir>", "Specify data directory", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-generate", strprintf("Generate blocks immediately, equivalent to RPC generatenewaddress followed by RPC generatetoaddress. Optional positional integer arguments are number of blocks to generate (default: %s) and maximum iterations to try (default: %s), equivalent to RPC generatetoaddress nblocks and maxtries arguments. Example: bitcoin-cli -generate 4 1000", DEFAULT_NBLOCKS, DEFAULT_MAX_TRIES), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-getinfo", "Get general information from the remote server. Note that unlike server-side RPC calls, the results of -getinfo is the result of multiple non-atomic requests. Some entries in the result may represent results from different states (e.g. wallet balance may be as of a different block from the chain state reported)", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + SetupChainParamsBaseOptions(argsman); + argsman.AddArg("-named", strprintf("Pass named instead of positional arguments (default: %s)", DEFAULT_NAMED), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-rpcclienttimeout=<n>", strprintf("Timeout in seconds during HTTP requests, or 0 for no timeout. (default: %d)", DEFAULT_HTTP_CLIENT_TIMEOUT), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-rpcconnect=<ip>", strprintf("Send commands to node running on <ip> (default: %s)", DEFAULT_RPCCONNECT), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-rpccookiefile=<loc>", "Location of the auth cookie. Relative paths will be prefixed by a net-specific datadir location. (default: data dir)", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-rpcpassword=<pw>", "Password for JSON-RPC connections", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-rpcport=<port>", strprintf("Connect to JSON-RPC on <port> (default: %u, testnet: %u, regtest: %u)", defaultBaseParams->RPCPort(), testnetBaseParams->RPCPort(), regtestBaseParams->RPCPort()), ArgsManager::ALLOW_ANY | ArgsManager::NETWORK_ONLY, OptionsCategory::OPTIONS); + argsman.AddArg("-rpcuser=<user>", "Username for JSON-RPC connections", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-rpcwait", "Wait for RPC server to start", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-rpcwallet=<walletname>", "Send RPC for non-default wallet on RPC server (needs to exactly match corresponding -wallet option passed to bitcoind). This changes the RPC endpoint used, e.g. http://127.0.0.1:8332/wallet/<walletname>", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-stdin", "Read extra arguments from standard input, one per line until EOF/Ctrl-D (recommended for sensitive information such as passphrases). When combined with -stdinrpcpass, the first line from standard input is used for the RPC password.", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-stdinrpcpass", "Read RPC password from standard input as a single line. When combined with -stdin, the first line from standard input is used for the RPC password. When combined with -stdinwalletpassphrase, -stdinrpcpass consumes the first line, and -stdinwalletpassphrase consumes the second.", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-stdinwalletpassphrase", "Read wallet passphrase from standard input as a single line. When combined with -stdin, the first line from standard input is used for the wallet passphrase.", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); } /** libevent event log callback */ @@ -106,7 +111,7 @@ static int AppInitRPC(int argc, char* argv[]) // // Parameters // - SetupCliArgs(); + SetupCliArgs(gArgs); std::string error; if (!gArgs.ParseParameters(argc, argv, error)) { tfm::format(std::cerr, "Error parsing command line arguments: %s\n", error); @@ -286,6 +291,28 @@ public: } }; +/** Process RPC generatetoaddress request. */ +class GenerateToAddressRequestHandler : public BaseRequestHandler +{ +public: + UniValue PrepareRequest(const std::string& method, const std::vector<std::string>& args) override + { + address_str = args.at(1); + UniValue params{RPCConvertValues("generatetoaddress", args)}; + return JSONRPCRequestObj("generatetoaddress", params, 1); + } + + UniValue ProcessReply(const UniValue &reply) override + { + UniValue result(UniValue::VOBJ); + result.pushKV("address", address_str); + result.pushKV("blocks", reply.get_obj()["result"]); + return JSONRPCReplyObj(result, NullUniValue, 1); + } +protected: + std::string address_str; +}; + /** Process default single requests */ class DefaultRequestHandler: public BaseRequestHandler { public: @@ -453,6 +480,34 @@ static UniValue ConnectAndCallRPC(BaseRequestHandler* rh, const std::string& str return response; } +/** Parse UniValue result to update the message to print to std::cout. */ +static void ParseResult(const UniValue& result, std::string& strPrint) +{ + if (result.isNull()) return; + strPrint = result.isStr() ? result.get_str() : result.write(2); +} + +/** Parse UniValue error to update the message to print to std::cerr and the code to return. */ +static void ParseError(const UniValue& error, std::string& strPrint, int& nRet) +{ + if (error.isObject()) { + const UniValue& err_code = find_value(error, "code"); + const UniValue& err_msg = find_value(error, "message"); + if (!err_code.isNull()) { + strPrint = "error code: " + err_code.getValStr() + "\n"; + } + if (err_msg.isStr()) { + strPrint += ("error message:\n" + err_msg.get_str()); + } + if (err_code.isNum() && err_code.get_int() == RPC_WALLET_NOT_SPECIFIED) { + strPrint += "\nTry adding \"-rpcwallet=<filename>\" option to bitcoin-cli command line."; + } + } else { + strPrint = "error: " + error.write(); + } + nRet = abs(error["code"].get_int()); +} + /** * GetWalletBalances calls listwallets; if more than one wallet is loaded, it then * fetches mine.trusted balances for each loaded wallet and pushes them to `result`. @@ -461,8 +516,8 @@ static UniValue ConnectAndCallRPC(BaseRequestHandler* rh, const std::string& str */ static void GetWalletBalances(UniValue& result) { - std::unique_ptr<BaseRequestHandler> rh{MakeUnique<DefaultRequestHandler>()}; - const UniValue listwallets = ConnectAndCallRPC(rh.get(), "listwallets", /* args=*/{}); + DefaultRequestHandler rh; + const UniValue listwallets = ConnectAndCallRPC(&rh, "listwallets", /* args=*/{}); if (!find_value(listwallets, "error").isNull()) return; const UniValue& wallets = find_value(listwallets, "result"); if (wallets.size() <= 1) return; @@ -470,13 +525,41 @@ static void GetWalletBalances(UniValue& result) UniValue balances(UniValue::VOBJ); for (const UniValue& wallet : wallets.getValues()) { const std::string wallet_name = wallet.get_str(); - const UniValue getbalances = ConnectAndCallRPC(rh.get(), "getbalances", /* args=*/{}, wallet_name); + const UniValue getbalances = ConnectAndCallRPC(&rh, "getbalances", /* args=*/{}, wallet_name); const UniValue& balance = find_value(getbalances, "result")["mine"]["trusted"]; balances.pushKV(wallet_name, balance); } result.pushKV("balances", balances); } +/** + * Call RPC getnewaddress. + * @returns getnewaddress response as a UniValue object. + */ +static UniValue GetNewAddress() +{ + Optional<std::string> wallet_name{}; + if (gArgs.IsArgSet("-rpcwallet")) wallet_name = gArgs.GetArg("-rpcwallet", ""); + DefaultRequestHandler rh; + return ConnectAndCallRPC(&rh, "getnewaddress", /* args=*/{}, wallet_name); +} + +/** + * Check bounds and set up args for RPC generatetoaddress params: nblocks, address, maxtries. + * @param[in] address Reference to const string address to insert into the args. + * @param args Reference to vector of string args to modify. + */ +static void SetGenerateToAddressArgs(const std::string& address, std::vector<std::string>& args) +{ + if (args.size() > 2) throw std::runtime_error("too many arguments (maximum 2 for nblocks and maxtries)"); + if (args.size() == 0) { + args.emplace_back(DEFAULT_NBLOCKS); + } else if (args.at(0) == "0") { + throw std::runtime_error("the first argument (number of blocks to generate, default: " + DEFAULT_NBLOCKS + ") must be an integer value greater than zero"); + } + args.emplace(args.begin() + 1, address); +} + static int CommandLineRPC(int argc, char *argv[]) { std::string strPrint; @@ -535,6 +618,15 @@ static int CommandLineRPC(int argc, char *argv[]) std::string method; if (gArgs.IsArgSet("-getinfo")) { rh.reset(new GetinfoRequestHandler()); + } else if (gArgs.GetBoolArg("-generate", false)) { + const UniValue getnewaddress{GetNewAddress()}; + const UniValue& error{find_value(getnewaddress, "error")}; + if (error.isNull()) { + SetGenerateToAddressArgs(find_value(getnewaddress, "result").get_str(), args); + rh.reset(new GenerateToAddressRequestHandler()); + } else { + ParseError(error, strPrint, nRet); + } } else { rh.reset(new DefaultRequestHandler()); if (args.size() < 1) { @@ -543,40 +635,22 @@ static int CommandLineRPC(int argc, char *argv[]) method = args[0]; args.erase(args.begin()); // Remove trailing method name from arguments vector } - Optional<std::string> wallet_name{}; - if (gArgs.IsArgSet("-rpcwallet")) wallet_name = gArgs.GetArg("-rpcwallet", ""); - const UniValue reply = ConnectAndCallRPC(rh.get(), method, args, wallet_name); - - // Parse reply - UniValue result = find_value(reply, "result"); - const UniValue& error = find_value(reply, "error"); - if (!error.isNull()) { - // Error - strPrint = "error: " + error.write(); - nRet = abs(error["code"].get_int()); - if (error.isObject()) { - const UniValue& errCode = find_value(error, "code"); - const UniValue& errMsg = find_value(error, "message"); - strPrint = errCode.isNull() ? "" : ("error code: " + errCode.getValStr() + "\n"); - - if (errMsg.isStr()) { - strPrint += ("error message:\n" + errMsg.get_str()); - } - if (errCode.isNum() && errCode.get_int() == RPC_WALLET_NOT_SPECIFIED) { - strPrint += "\nTry adding \"-rpcwallet=<filename>\" option to bitcoin-cli command line."; + if (nRet == 0) { + // Perform RPC call + Optional<std::string> wallet_name{}; + if (gArgs.IsArgSet("-rpcwallet")) wallet_name = gArgs.GetArg("-rpcwallet", ""); + const UniValue reply = ConnectAndCallRPC(rh.get(), method, args, wallet_name); + + // Parse reply + UniValue result = find_value(reply, "result"); + const UniValue& error = find_value(reply, "error"); + if (error.isNull()) { + if (gArgs.IsArgSet("-getinfo") && !gArgs.IsArgSet("-rpcwallet")) { + GetWalletBalances(result); // fetch multiwallet balances and append to result } - } - } else { - if (gArgs.IsArgSet("-getinfo") && !gArgs.IsArgSet("-rpcwallet")) { - GetWalletBalances(result); // fetch multiwallet balances and append to result - } - // Result - if (result.isNull()) { - strPrint = ""; - } else if (result.isStr()) { - strPrint = result.get_str(); + ParseResult(result, strPrint); } else { - strPrint = result.write(2); + ParseError(error, strPrint, nRet); } } } catch (const std::exception& e) { diff --git a/src/bitcoin-tx.cpp b/src/bitcoin-tx.cpp index f54a299a36..56afcb6ded 100644 --- a/src/bitcoin-tx.cpp +++ b/src/bitcoin-tx.cpp @@ -36,40 +36,40 @@ static const int CONTINUE_EXECUTION=-1; const std::function<std::string(const char*)> G_TRANSLATION_FUN = nullptr; -static void SetupBitcoinTxArgs() +static void SetupBitcoinTxArgs(ArgsManager &argsman) { - SetupHelpOptions(gArgs); - - gArgs.AddArg("-create", "Create new, empty TX.", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-json", "Select JSON output", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-txid", "Output only the hex-encoded transaction id of the resultant transaction.", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - SetupChainParamsBaseOptions(); - - gArgs.AddArg("delin=N", "Delete input N from TX", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); - gArgs.AddArg("delout=N", "Delete output N from TX", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); - gArgs.AddArg("in=TXID:VOUT(:SEQUENCE_NUMBER)", "Add input to TX", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); - gArgs.AddArg("locktime=N", "Set TX lock time to N", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); - gArgs.AddArg("nversion=N", "Set TX version to N", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); - gArgs.AddArg("outaddr=VALUE:ADDRESS", "Add address-based output to TX", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); - gArgs.AddArg("outdata=[VALUE:]DATA", "Add data-based output to TX", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); - gArgs.AddArg("outmultisig=VALUE:REQUIRED:PUBKEYS:PUBKEY1:PUBKEY2:....[:FLAGS]", "Add Pay To n-of-m Multi-sig output to TX. n = REQUIRED, m = PUBKEYS. " + SetupHelpOptions(argsman); + + argsman.AddArg("-create", "Create new, empty TX.", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-json", "Select JSON output", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-txid", "Output only the hex-encoded transaction id of the resultant transaction.", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + SetupChainParamsBaseOptions(argsman); + + argsman.AddArg("delin=N", "Delete input N from TX", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); + argsman.AddArg("delout=N", "Delete output N from TX", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); + argsman.AddArg("in=TXID:VOUT(:SEQUENCE_NUMBER)", "Add input to TX", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); + argsman.AddArg("locktime=N", "Set TX lock time to N", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); + argsman.AddArg("nversion=N", "Set TX version to N", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); + argsman.AddArg("outaddr=VALUE:ADDRESS", "Add address-based output to TX", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); + argsman.AddArg("outdata=[VALUE:]DATA", "Add data-based output to TX", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); + argsman.AddArg("outmultisig=VALUE:REQUIRED:PUBKEYS:PUBKEY1:PUBKEY2:....[:FLAGS]", "Add Pay To n-of-m Multi-sig output to TX. n = REQUIRED, m = PUBKEYS. " "Optionally add the \"W\" flag to produce a pay-to-witness-script-hash output. " "Optionally add the \"S\" flag to wrap the output in a pay-to-script-hash.", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); - gArgs.AddArg("outpubkey=VALUE:PUBKEY[:FLAGS]", "Add pay-to-pubkey output to TX. " + argsman.AddArg("outpubkey=VALUE:PUBKEY[:FLAGS]", "Add pay-to-pubkey output to TX. " "Optionally add the \"W\" flag to produce a pay-to-witness-pubkey-hash output. " "Optionally add the \"S\" flag to wrap the output in a pay-to-script-hash.", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); - gArgs.AddArg("outscript=VALUE:SCRIPT[:FLAGS]", "Add raw script output to TX. " + argsman.AddArg("outscript=VALUE:SCRIPT[:FLAGS]", "Add raw script output to TX. " "Optionally add the \"W\" flag to produce a pay-to-witness-script-hash output. " "Optionally add the \"S\" flag to wrap the output in a pay-to-script-hash.", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); - gArgs.AddArg("replaceable(=N)", "Set RBF opt-in sequence number for input N (if not provided, opt-in all available inputs)", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); - gArgs.AddArg("sign=SIGHASH-FLAGS", "Add zero or more signatures to transaction. " + argsman.AddArg("replaceable(=N)", "Set RBF opt-in sequence number for input N (if not provided, opt-in all available inputs)", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); + argsman.AddArg("sign=SIGHASH-FLAGS", "Add zero or more signatures to transaction. " "This command requires JSON registers:" "prevtxs=JSON object, " "privatekeys=JSON object. " "See signrawtransactionwithkey docs for format of sighash flags, JSON objects.", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); - gArgs.AddArg("load=NAME:FILENAME", "Load JSON file FILENAME into register NAME", ArgsManager::ALLOW_ANY, OptionsCategory::REGISTER_COMMANDS); - gArgs.AddArg("set=NAME:JSON-STRING", "Set register NAME to given JSON-STRING", ArgsManager::ALLOW_ANY, OptionsCategory::REGISTER_COMMANDS); + argsman.AddArg("load=NAME:FILENAME", "Load JSON file FILENAME into register NAME", ArgsManager::ALLOW_ANY, OptionsCategory::REGISTER_COMMANDS); + argsman.AddArg("set=NAME:JSON-STRING", "Set register NAME to given JSON-STRING", ArgsManager::ALLOW_ANY, OptionsCategory::REGISTER_COMMANDS); } // @@ -81,7 +81,7 @@ static int AppInitRawTx(int argc, char* argv[]) // // Parameters // - SetupBitcoinTxArgs(); + SetupBitcoinTxArgs(gArgs); std::string error; if (!gArgs.ParseParameters(argc, argv, error)) { tfm::format(std::cerr, "Error parsing command line arguments: %s\n", error); diff --git a/src/bitcoin-wallet.cpp b/src/bitcoin-wallet.cpp index b420463c00..06b0c86476 100644 --- a/src/bitcoin-wallet.cpp +++ b/src/bitcoin-wallet.cpp @@ -19,24 +19,24 @@ const std::function<std::string(const char*)> G_TRANSLATION_FUN = nullptr; UrlDecodeFn* const URL_DECODE = nullptr; -static void SetupWalletToolArgs() +static void SetupWalletToolArgs(ArgsManager& argsman) { - SetupHelpOptions(gArgs); - SetupChainParamsBaseOptions(); + SetupHelpOptions(argsman); + SetupChainParamsBaseOptions(argsman); - gArgs.AddArg("-datadir=<dir>", "Specify data directory", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-wallet=<wallet-name>", "Specify wallet name", ArgsManager::ALLOW_ANY | ArgsManager::NETWORK_ONLY, OptionsCategory::OPTIONS); - gArgs.AddArg("-debug=<category>", "Output debugging information (default: 0).", ArgsManager::ALLOW_ANY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-printtoconsole", "Send trace/debug info to console (default: 1 when no -debug is true, 0 otherwise).", ArgsManager::ALLOW_ANY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-datadir=<dir>", "Specify data directory", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-wallet=<wallet-name>", "Specify wallet name", ArgsManager::ALLOW_ANY | ArgsManager::NETWORK_ONLY, OptionsCategory::OPTIONS); + argsman.AddArg("-debug=<category>", "Output debugging information (default: 0).", ArgsManager::ALLOW_ANY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-printtoconsole", "Send trace/debug info to console (default: 1 when no -debug is true, 0 otherwise).", ArgsManager::ALLOW_ANY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("info", "Get wallet info", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); - gArgs.AddArg("create", "Create new wallet file", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); - gArgs.AddArg("salvage", "Attempt to recover private keys from a corrupt wallet", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); + argsman.AddArg("info", "Get wallet info", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); + argsman.AddArg("create", "Create new wallet file", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); + argsman.AddArg("salvage", "Attempt to recover private keys from a corrupt wallet", ArgsManager::ALLOW_ANY, OptionsCategory::COMMANDS); } static bool WalletAppInit(int argc, char* argv[]) { - SetupWalletToolArgs(); + SetupWalletToolArgs(gArgs); std::string error_message; if (!gArgs.ParseParameters(argc, argv, error_message)) { tfm::format(std::cerr, "Error parsing command line arguments: %s\n", error_message); diff --git a/src/bitcoind.cpp b/src/bitcoind.cpp index b8e8717896..b04cc12059 100644 --- a/src/bitcoind.cpp +++ b/src/bitcoind.cpp @@ -13,9 +13,9 @@ #include <init.h> #include <interfaces/chain.h> #include <node/context.h> +#include <node/ui_interface.h> #include <noui.h> #include <shutdown.h> -#include <ui_interface.h> #include <util/ref.h> #include <util/strencodings.h> #include <util/system.h> @@ -101,6 +101,11 @@ static bool AppInit(int argc, char* argv[]) } } + if (!gArgs.InitSettings(error)) { + InitError(Untranslated(error)); + return false; + } + // -server defaults to true for bitcoind but not for the GUI so do this here gArgs.SoftSetBoolArg("-server", true); // Set this early so that parameter interactions go to console diff --git a/src/blockfilter.cpp b/src/blockfilter.cpp index 5f5bed5bda..9a6fb4abd0 100644 --- a/src/blockfilter.cpp +++ b/src/blockfilter.cpp @@ -291,7 +291,7 @@ uint256 BlockFilter::GetHash() const const std::vector<unsigned char>& data = GetEncodedFilter(); uint256 result; - CHash256().Write(data.data(), data.size()).Finalize(result.begin()); + CHash256().Write(data).Finalize(result); return result; } @@ -301,8 +301,8 @@ uint256 BlockFilter::ComputeHeader(const uint256& prev_header) const uint256 result; CHash256() - .Write(filter_hash.begin(), filter_hash.size()) - .Write(prev_header.begin(), prev_header.size()) - .Finalize(result.begin()); + .Write(filter_hash) + .Write(prev_header) + .Finalize(result); return result; } diff --git a/src/bloom.cpp b/src/bloom.cpp index 54fcf487e4..d182f0728e 100644 --- a/src/bloom.cpp +++ b/src/bloom.cpp @@ -135,8 +135,8 @@ bool CBloomFilter::IsRelevantAndUpdate(const CTransaction& tx) else if ((nFlags & BLOOM_UPDATE_MASK) == BLOOM_UPDATE_P2PUBKEY_ONLY) { std::vector<std::vector<unsigned char> > vSolutions; - txnouttype type = Solver(txout.scriptPubKey, vSolutions); - if (type == TX_PUBKEY || type == TX_MULTISIG) { + TxoutType type = Solver(txout.scriptPubKey, vSolutions); + if (type == TxoutType::PUBKEY || type == TxoutType::MULTISIG) { insert(COutPoint(hash, i)); } } diff --git a/src/chainparams.cpp b/src/chainparams.cpp index 799474fae2..ffd2076c9a 100644 --- a/src/chainparams.cpp +++ b/src/chainparams.cpp @@ -110,7 +110,7 @@ public: // Note that of those which support the service bits prefix, most only support a subset of // possible options. - // This is fine at runtime as we'll fall back to using them as a oneshot if they don't support the + // This is fine at runtime as we'll fall back to using them as an addrfetch if they don't support the // service bits we want, but we should get them updated to support all service bits wanted by any // release ASAP to avoid it where possible. vSeeds.emplace_back("seed.bitcoin.sipa.be"); // Pieter Wuille, only supports x1, x5, x9, and xd @@ -121,6 +121,7 @@ public: vSeeds.emplace_back("seed.btc.petertodd.org"); // Peter Todd, only supports x1, x5, x9, and xd vSeeds.emplace_back("seed.bitcoin.sprovoost.nl"); // Sjors Provoost vSeeds.emplace_back("dnsseed.emzy.de"); // Stephan Oeste + vSeeds.emplace_back("seed.bitcoin.wiz.biz"); // Jason Maurice base58Prefixes[PUBKEY_ADDRESS] = std::vector<unsigned char>(1,0); base58Prefixes[SCRIPT_ADDRESS] = std::vector<unsigned char>(1,5); @@ -340,8 +341,8 @@ public: void CRegTestParams::UpdateActivationParametersFromArgs(const ArgsManager& args) { - if (gArgs.IsArgSet("-segwitheight")) { - int64_t height = gArgs.GetArg("-segwitheight", consensus.SegwitHeight); + if (args.IsArgSet("-segwitheight")) { + int64_t height = args.GetArg("-segwitheight", consensus.SegwitHeight); if (height < -1 || height >= std::numeric_limits<int>::max()) { throw std::runtime_error(strprintf("Activation height %ld for segwit is out of valid range. Use -1 to disable segwit.", height)); } else if (height == -1) { diff --git a/src/chainparamsbase.cpp b/src/chainparamsbase.cpp index 894b8553c4..1825ced640 100644 --- a/src/chainparamsbase.cpp +++ b/src/chainparamsbase.cpp @@ -15,14 +15,14 @@ const std::string CBaseChainParams::MAIN = "main"; const std::string CBaseChainParams::TESTNET = "test"; const std::string CBaseChainParams::REGTEST = "regtest"; -void SetupChainParamsBaseOptions() +void SetupChainParamsBaseOptions(ArgsManager& argsman) { - gArgs.AddArg("-chain=<chain>", "Use the chain <chain> (default: main). Allowed values: main, test, regtest", ArgsManager::ALLOW_ANY, OptionsCategory::CHAINPARAMS); - gArgs.AddArg("-regtest", "Enter regression test mode, which uses a special chain in which blocks can be solved instantly. " + argsman.AddArg("-chain=<chain>", "Use the chain <chain> (default: main). Allowed values: main, test, regtest", ArgsManager::ALLOW_ANY, OptionsCategory::CHAINPARAMS); + argsman.AddArg("-regtest", "Enter regression test mode, which uses a special chain in which blocks can be solved instantly. " "This is intended for regression testing tools and app development. Equivalent to -chain=regtest.", ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::CHAINPARAMS); - gArgs.AddArg("-segwitheight=<n>", "Set the activation height of segwit. -1 to disable. (regtest-only)", ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-testnet", "Use the test chain. Equivalent to -chain=test.", ArgsManager::ALLOW_ANY, OptionsCategory::CHAINPARAMS); - gArgs.AddArg("-vbparams=deployment:start:end", "Use given start/end times for specified version bits deployment (regtest-only)", ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::CHAINPARAMS); + argsman.AddArg("-segwitheight=<n>", "Set the activation height of segwit. -1 to disable. (regtest-only)", ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-testnet", "Use the test chain. Equivalent to -chain=test.", ArgsManager::ALLOW_ANY, OptionsCategory::CHAINPARAMS); + argsman.AddArg("-vbparams=deployment:start:end", "Use given start/end times for specified version bits deployment (regtest-only)", ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::CHAINPARAMS); } static std::unique_ptr<CBaseChainParams> globalChainBaseParams; diff --git a/src/chainparamsbase.h b/src/chainparamsbase.h index 3c139931ea..1c52d0ea97 100644 --- a/src/chainparamsbase.h +++ b/src/chainparamsbase.h @@ -8,6 +8,8 @@ #include <memory> #include <string> +class ArgsManager; + /** * CBaseChainParams defines the base parameters (shared between bitcoin-cli and bitcoind) * of a given instance of the Bitcoin system. @@ -43,7 +45,7 @@ std::unique_ptr<CBaseChainParams> CreateBaseChainParams(const std::string& chain /** *Set the arguments for chainparams */ -void SetupChainParamsBaseOptions(); +void SetupChainParamsBaseOptions(ArgsManager& argsman); /** * Return the currently selected parameters. This won't change after app diff --git a/src/coins.cpp b/src/coins.cpp index 7b76c13f98..5de2ed7810 100644 --- a/src/coins.cpp +++ b/src/coins.cpp @@ -245,6 +245,14 @@ bool CCoinsViewCache::HaveInputs(const CTransaction& tx) const return true; } +void CCoinsViewCache::ReallocateCache() +{ + // Cache should be empty when we're calling this. + assert(cacheCoins.size() == 0); + cacheCoins.~CCoinsMap(); + ::new (&cacheCoins) CCoinsMap(); +} + static const size_t MIN_TRANSACTION_OUTPUT_WEIGHT = WITNESS_SCALE_FACTOR * ::GetSerializeSize(CTxOut(), PROTOCOL_VERSION); static const size_t MAX_OUTPUTS_PER_BLOCK = MAX_BLOCK_WEIGHT / MIN_TRANSACTION_OUTPUT_WEIGHT; diff --git a/src/coins.h b/src/coins.h index a3f34bb0ee..a3e241ac90 100644 --- a/src/coins.h +++ b/src/coins.h @@ -318,6 +318,13 @@ public: //! Check whether all prevouts of the transaction are present in the UTXO set represented by this view bool HaveInputs(const CTransaction& tx) const; + //! Force a reallocation of the cache map. This is required when downsizing + //! the cache because the map's allocator may be hanging onto a lot of + //! memory despite having called .clear(). + //! + //! See: https://stackoverflow.com/questions/42114044/how-to-release-unordered-map-memory + void ReallocateCache(); + private: /** * @note this is marked const, but may actually append to `cacheCoins`, increasing diff --git a/src/consensus/tx_verify.cpp b/src/consensus/tx_verify.cpp index 81245e3e11..9e8e6530f1 100644 --- a/src/consensus/tx_verify.cpp +++ b/src/consensus/tx_verify.cpp @@ -27,9 +27,9 @@ bool IsFinalTx(const CTransaction &tx, int nBlockHeight, int64_t nBlockTime) return true; } -std::pair<int, int64_t> CalculateSequenceLocks(const CTransaction &tx, int flags, std::vector<int>* prevHeights, const CBlockIndex& block) +std::pair<int, int64_t> CalculateSequenceLocks(const CTransaction &tx, int flags, std::vector<int>& prevHeights, const CBlockIndex& block) { - assert(prevHeights->size() == tx.vin.size()); + assert(prevHeights.size() == tx.vin.size()); // Will be set to the equivalent height- and time-based nLockTime // values that would be necessary to satisfy all relative lock- @@ -59,11 +59,11 @@ std::pair<int, int64_t> CalculateSequenceLocks(const CTransaction &tx, int flags // consensus-enforced meaning at this point. if (txin.nSequence & CTxIn::SEQUENCE_LOCKTIME_DISABLE_FLAG) { // The height of this input is not relevant for sequence locks - (*prevHeights)[txinIndex] = 0; + prevHeights[txinIndex] = 0; continue; } - int nCoinHeight = (*prevHeights)[txinIndex]; + int nCoinHeight = prevHeights[txinIndex]; if (txin.nSequence & CTxIn::SEQUENCE_LOCKTIME_TYPE_FLAG) { int64_t nCoinTime = block.GetAncestor(std::max(nCoinHeight-1, 0))->GetMedianTimePast(); @@ -99,7 +99,7 @@ bool EvaluateSequenceLocks(const CBlockIndex& block, std::pair<int, int64_t> loc return true; } -bool SequenceLocks(const CTransaction &tx, int flags, std::vector<int>* prevHeights, const CBlockIndex& block) +bool SequenceLocks(const CTransaction &tx, int flags, std::vector<int>& prevHeights, const CBlockIndex& block) { return EvaluateSequenceLocks(block, CalculateSequenceLocks(tx, flags, prevHeights, block)); } diff --git a/src/consensus/tx_verify.h b/src/consensus/tx_verify.h index ffcaf3cab1..e2a9328df8 100644 --- a/src/consensus/tx_verify.h +++ b/src/consensus/tx_verify.h @@ -66,13 +66,13 @@ bool IsFinalTx(const CTransaction &tx, int nBlockHeight, int64_t nBlockTime); * Also removes from the vector of input heights any entries which did not * correspond to sequence locked inputs as they do not affect the calculation. */ -std::pair<int, int64_t> CalculateSequenceLocks(const CTransaction &tx, int flags, std::vector<int>* prevHeights, const CBlockIndex& block); +std::pair<int, int64_t> CalculateSequenceLocks(const CTransaction &tx, int flags, std::vector<int>& prevHeights, const CBlockIndex& block); bool EvaluateSequenceLocks(const CBlockIndex& block, std::pair<int, int64_t> lockPair); /** * Check if transaction is final per BIP 68 sequence numbers and can be included in a block. * Consensus critical. Takes as input a list of heights at which tx's inputs (in order) confirmed. */ -bool SequenceLocks(const CTransaction &tx, int flags, std::vector<int>* prevHeights, const CBlockIndex& block); +bool SequenceLocks(const CTransaction &tx, int flags, std::vector<int>& prevHeights, const CBlockIndex& block); #endif // BITCOIN_CONSENSUS_TX_VERIFY_H diff --git a/src/consensus/validation.h b/src/consensus/validation.h index a79e7b9d12..2a93a090d6 100644 --- a/src/consensus/validation.h +++ b/src/consensus/validation.h @@ -26,16 +26,21 @@ enum class TxValidationResult { * is uninteresting. */ TX_RECENT_CONSENSUS_CHANGE, - TX_NOT_STANDARD, //!< didn't meet our local policy rules + TX_INPUTS_NOT_STANDARD, //!< inputs (covered by txid) failed policy rules + TX_NOT_STANDARD, //!< otherwise didn't meet our local policy rules TX_MISSING_INPUTS, //!< transaction was missing some of its inputs TX_PREMATURE_SPEND, //!< transaction spends a coinbase too early, or violates locktime/sequence locks /** - * Transaction might be missing a witness, have a witness prior to SegWit + * Transaction might have a witness prior to SegWit * activation, or witness may have been malleated (which includes * non-standard witnesses). */ TX_WITNESS_MUTATED, /** + * Transaction is missing a witness. + */ + TX_WITNESS_STRIPPED, + /** * Tx already in mempool or conflicts with a tx in the chain * (if it conflicts with another tx in mempool, we use MEMPOOL_POLICY as it failed to reach the RBF threshold) * Currently this is only used if the transaction already exists in the mempool or on chain. @@ -75,37 +80,39 @@ enum class BlockValidationResult { * by TxValidationState and BlockValidationState for validation information on transactions * and blocks respectively. */ template <typename Result> -class ValidationState { +class ValidationState +{ private: - enum mode_state { - MODE_VALID, //!< everything ok - MODE_INVALID, //!< network rule violation (DoS value may be set) - MODE_ERROR, //!< run-time error - } m_mode{MODE_VALID}; + enum class ModeState { + M_VALID, //!< everything ok + M_INVALID, //!< network rule violation (DoS value may be set) + M_ERROR, //!< run-time error + } m_mode{ModeState::M_VALID}; Result m_result{}; std::string m_reject_reason; std::string m_debug_message; + public: bool Invalid(Result result, - const std::string &reject_reason="", - const std::string &debug_message="") + const std::string& reject_reason = "", + const std::string& debug_message = "") { m_result = result; m_reject_reason = reject_reason; m_debug_message = debug_message; - if (m_mode != MODE_ERROR) m_mode = MODE_INVALID; + if (m_mode != ModeState::M_ERROR) m_mode = ModeState::M_INVALID; return false; } bool Error(const std::string& reject_reason) { - if (m_mode == MODE_VALID) + if (m_mode == ModeState::M_VALID) m_reject_reason = reject_reason; - m_mode = MODE_ERROR; + m_mode = ModeState::M_ERROR; return false; } - bool IsValid() const { return m_mode == MODE_VALID; } - bool IsInvalid() const { return m_mode == MODE_INVALID; } - bool IsError() const { return m_mode == MODE_ERROR; } + bool IsValid() const { return m_mode == ModeState::M_VALID; } + bool IsInvalid() const { return m_mode == ModeState::M_INVALID; } + bool IsError() const { return m_mode == ModeState::M_ERROR; } Result GetResult() const { return m_result; } std::string GetRejectReason() const { return m_reject_reason; } std::string GetDebugMessage() const { return m_debug_message; } diff --git a/src/core_write.cpp b/src/core_write.cpp index cb1fc214eb..f9d918cb6d 100644 --- a/src/core_write.cpp +++ b/src/core_write.cpp @@ -48,13 +48,14 @@ std::string FormatScript(const CScript& script) } } if (vch.size() > 0) { - ret += strprintf("0x%x 0x%x ", HexStr(it2, it - vch.size()), HexStr(it - vch.size(), it)); + ret += strprintf("0x%x 0x%x ", HexStr(std::vector<uint8_t>(it2, it - vch.size())), + HexStr(std::vector<uint8_t>(it - vch.size(), it))); } else { - ret += strprintf("0x%x ", HexStr(it2, it)); + ret += strprintf("0x%x ", HexStr(std::vector<uint8_t>(it2, it))); } continue; } - ret += strprintf("0x%x ", HexStr(it2, script.end())); + ret += strprintf("0x%x ", HexStr(std::vector<uint8_t>(it2, script.end()))); break; } return ret.substr(0, ret.size() - 1); @@ -131,20 +132,20 @@ std::string EncodeHexTx(const CTransaction& tx, const int serializeFlags) { CDataStream ssTx(SER_NETWORK, PROTOCOL_VERSION | serializeFlags); ssTx << tx; - return HexStr(ssTx.begin(), ssTx.end()); + return HexStr(ssTx); } void ScriptToUniv(const CScript& script, UniValue& out, bool include_address) { out.pushKV("asm", ScriptToAsmStr(script)); - out.pushKV("hex", HexStr(script.begin(), script.end())); + out.pushKV("hex", HexStr(script)); std::vector<std::vector<unsigned char>> solns; - txnouttype type = Solver(script, solns); + TxoutType type = Solver(script, solns); out.pushKV("type", GetTxnOutputType(type)); CTxDestination address; - if (include_address && ExtractDestination(script, address) && type != TX_PUBKEY) { + if (include_address && ExtractDestination(script, address) && type != TxoutType::PUBKEY) { out.pushKV("address", EncodeDestination(address)); } } @@ -152,15 +153,15 @@ void ScriptToUniv(const CScript& script, UniValue& out, bool include_address) void ScriptPubKeyToUniv(const CScript& scriptPubKey, UniValue& out, bool fIncludeHex) { - txnouttype type; + TxoutType type; std::vector<CTxDestination> addresses; int nRequired; out.pushKV("asm", ScriptToAsmStr(scriptPubKey)); if (fIncludeHex) - out.pushKV("hex", HexStr(scriptPubKey.begin(), scriptPubKey.end())); + out.pushKV("hex", HexStr(scriptPubKey)); - if (!ExtractDestinations(scriptPubKey, type, addresses, nRequired) || type == TX_PUBKEY) { + if (!ExtractDestinations(scriptPubKey, type, addresses, nRequired) || type == TxoutType::PUBKEY) { out.pushKV("type", GetTxnOutputType(type)); return; } @@ -179,7 +180,9 @@ void TxToUniv(const CTransaction& tx, const uint256& hashBlock, UniValue& entry, { entry.pushKV("txid", tx.GetHash().GetHex()); entry.pushKV("hash", tx.GetWitnessHash().GetHex()); - entry.pushKV("version", tx.nVersion); + // Transaction version is actually unsigned in consensus checks, just signed in memory, + // so cast to unsigned before giving it to the user. + entry.pushKV("version", static_cast<int64_t>(static_cast<uint32_t>(tx.nVersion))); entry.pushKV("size", (int)::GetSerializeSize(tx, PROTOCOL_VERSION)); entry.pushKV("vsize", (GetTransactionWeight(tx) + WITNESS_SCALE_FACTOR - 1) / WITNESS_SCALE_FACTOR); entry.pushKV("weight", GetTransactionWeight(tx)); @@ -190,21 +193,21 @@ void TxToUniv(const CTransaction& tx, const uint256& hashBlock, UniValue& entry, const CTxIn& txin = tx.vin[i]; UniValue in(UniValue::VOBJ); if (tx.IsCoinBase()) - in.pushKV("coinbase", HexStr(txin.scriptSig.begin(), txin.scriptSig.end())); + in.pushKV("coinbase", HexStr(txin.scriptSig)); else { in.pushKV("txid", txin.prevout.hash.GetHex()); in.pushKV("vout", (int64_t)txin.prevout.n); UniValue o(UniValue::VOBJ); o.pushKV("asm", ScriptToAsmStr(txin.scriptSig, true)); - o.pushKV("hex", HexStr(txin.scriptSig.begin(), txin.scriptSig.end())); + o.pushKV("hex", HexStr(txin.scriptSig)); in.pushKV("scriptSig", o); - if (!tx.vin[i].scriptWitness.IsNull()) { - UniValue txinwitness(UniValue::VARR); - for (const auto& item : tx.vin[i].scriptWitness.stack) { - txinwitness.push_back(HexStr(item.begin(), item.end())); - } - in.pushKV("txinwitness", txinwitness); + } + if (!tx.vin[i].scriptWitness.IsNull()) { + UniValue txinwitness(UniValue::VARR); + for (const auto& item : tx.vin[i].scriptWitness.stack) { + txinwitness.push_back(HexStr(item)); } + in.pushKV("txinwitness", txinwitness); } in.pushKV("sequence", (int64_t)txin.nSequence); vin.push_back(in); diff --git a/src/crypto/common.h b/src/crypto/common.h index e7bb020a19..5b4932c992 100644 --- a/src/crypto/common.h +++ b/src/crypto/common.h @@ -82,12 +82,12 @@ void static inline WriteBE64(unsigned char* ptr, uint64_t x) /** Return the smallest number n such that (x >> n) == 0 (or 64 if the highest bit in x is set. */ uint64_t static inline CountBits(uint64_t x) { -#if HAVE_DECL___BUILTIN_CLZL +#if HAVE_BUILTIN_CLZL if (sizeof(unsigned long) >= sizeof(uint64_t)) { return x ? 8 * sizeof(unsigned long) - __builtin_clzl(x) : 0; } #endif -#if HAVE_DECL___BUILTIN_CLZLL +#if HAVE_BUILTIN_CLZLL if (sizeof(unsigned long long) >= sizeof(uint64_t)) { return x ? 8 * sizeof(unsigned long long) - __builtin_clzll(x) : 0; } diff --git a/src/dbwrapper.h b/src/dbwrapper.h index 116d7d8679..215b033708 100644 --- a/src/dbwrapper.h +++ b/src/dbwrapper.h @@ -292,18 +292,6 @@ public: // Get an estimate of LevelDB memory usage (in bytes). size_t DynamicMemoryUsage() const; - // not available for LevelDB; provide for compatibility with BDB - bool Flush() - { - return true; - } - - bool Sync() - { - CDBBatch batch(*this); - return WriteBatch(batch, true); - } - CDBIterator *NewIterator() { return new CDBIterator(*this, pdb->NewIterator(iteroptions)); diff --git a/src/dummywallet.cpp b/src/dummywallet.cpp index 0f7848bae1..18dc7a69e2 100644 --- a/src/dummywallet.cpp +++ b/src/dummywallet.cpp @@ -20,14 +20,14 @@ class DummyWalletInit : public WalletInitInterface { public: bool HasWalletSupport() const override {return false;} - void AddWalletOptions() const override; + void AddWalletOptions(ArgsManager& argsman) const override; bool ParameterInteraction() const override {return true;} void Construct(NodeContext& node) const override {LogPrintf("No wallet support compiled in!\n");} }; -void DummyWalletInit::AddWalletOptions() const +void DummyWalletInit::AddWalletOptions(ArgsManager& argsman) const { - gArgs.AddHiddenArgs({ + argsman.AddHiddenArgs({ "-addresstype", "-avoidpartialspends", "-changetype", diff --git a/src/fs.cpp b/src/fs.cpp index e68c97b3ca..eef9c81de9 100644 --- a/src/fs.cpp +++ b/src/fs.cpp @@ -5,10 +5,12 @@ #include <fs.h> #ifndef WIN32 +#include <cstring> #include <fcntl.h> #include <string> #include <sys/file.h> #include <sys/utsname.h> +#include <unistd.h> #else #ifndef NOMINMAX #define NOMINMAX @@ -31,7 +33,8 @@ FILE *fopen(const fs::path& p, const char *mode) #ifndef WIN32 -static std::string GetErrorReason() { +static std::string GetErrorReason() +{ return std::strerror(errno); } diff --git a/src/hash.cpp b/src/hash.cpp index 26150e5ca8..4c09f5f646 100644 --- a/src/hash.cpp +++ b/src/hash.cpp @@ -12,7 +12,7 @@ inline uint32_t ROTL32(uint32_t x, int8_t r) return (x << r) | (x >> (32 - r)); } -unsigned int MurmurHash3(unsigned int nHashSeed, const std::vector<unsigned char>& vDataToHash) +unsigned int MurmurHash3(unsigned int nHashSeed, Span<const unsigned char> vDataToHash) { // The following is MurmurHash3 (x86_32), see http://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.cpp uint32_t h1 = nHashSeed; diff --git a/src/hash.h b/src/hash.h index c295568a3e..71806483ff 100644 --- a/src/hash.h +++ b/src/hash.h @@ -25,14 +25,15 @@ private: public: static const size_t OUTPUT_SIZE = CSHA256::OUTPUT_SIZE; - void Finalize(unsigned char hash[OUTPUT_SIZE]) { + void Finalize(Span<unsigned char> output) { + assert(output.size() == OUTPUT_SIZE); unsigned char buf[CSHA256::OUTPUT_SIZE]; sha.Finalize(buf); - sha.Reset().Write(buf, CSHA256::OUTPUT_SIZE).Finalize(hash); + sha.Reset().Write(buf, CSHA256::OUTPUT_SIZE).Finalize(output.data()); } - CHash256& Write(const unsigned char *data, size_t len) { - sha.Write(data, len); + CHash256& Write(Span<const unsigned char> input) { + sha.Write(input.data(), input.size()); return *this; } @@ -49,14 +50,15 @@ private: public: static const size_t OUTPUT_SIZE = CRIPEMD160::OUTPUT_SIZE; - void Finalize(unsigned char hash[OUTPUT_SIZE]) { + void Finalize(Span<unsigned char> output) { + assert(output.size() == OUTPUT_SIZE); unsigned char buf[CSHA256::OUTPUT_SIZE]; sha.Finalize(buf); - CRIPEMD160().Write(buf, CSHA256::OUTPUT_SIZE).Finalize(hash); + CRIPEMD160().Write(buf, CSHA256::OUTPUT_SIZE).Finalize(output.data()); } - CHash160& Write(const unsigned char *data, size_t len) { - sha.Write(data, len); + CHash160& Write(Span<const unsigned char> input) { + sha.Write(input.data(), input.size()); return *this; } @@ -67,52 +69,31 @@ public: }; /** Compute the 256-bit hash of an object. */ -template<typename T1> -inline uint256 Hash(const T1 pbegin, const T1 pend) +template<typename T> +inline uint256 Hash(const T& in1) { - static const unsigned char pblank[1] = {}; uint256 result; - CHash256().Write(pbegin == pend ? pblank : (const unsigned char*)&pbegin[0], (pend - pbegin) * sizeof(pbegin[0])) - .Finalize((unsigned char*)&result); + CHash256().Write(MakeUCharSpan(in1)).Finalize(result); return result; } /** Compute the 256-bit hash of the concatenation of two objects. */ template<typename T1, typename T2> -inline uint256 Hash(const T1 p1begin, const T1 p1end, - const T2 p2begin, const T2 p2end) { - static const unsigned char pblank[1] = {}; +inline uint256 Hash(const T1& in1, const T2& in2) { uint256 result; - CHash256().Write(p1begin == p1end ? pblank : (const unsigned char*)&p1begin[0], (p1end - p1begin) * sizeof(p1begin[0])) - .Write(p2begin == p2end ? pblank : (const unsigned char*)&p2begin[0], (p2end - p2begin) * sizeof(p2begin[0])) - .Finalize((unsigned char*)&result); + CHash256().Write(MakeUCharSpan(in1)).Write(MakeUCharSpan(in2)).Finalize(result); return result; } /** Compute the 160-bit hash an object. */ template<typename T1> -inline uint160 Hash160(const T1 pbegin, const T1 pend) +inline uint160 Hash160(const T1& in1) { - static unsigned char pblank[1] = {}; uint160 result; - CHash160().Write(pbegin == pend ? pblank : (const unsigned char*)&pbegin[0], (pend - pbegin) * sizeof(pbegin[0])) - .Finalize((unsigned char*)&result); + CHash160().Write(MakeUCharSpan(in1)).Finalize(result); return result; } -/** Compute the 160-bit hash of a vector. */ -inline uint160 Hash160(const std::vector<unsigned char>& vch) -{ - return Hash160(vch.begin(), vch.end()); -} - -/** Compute the 160-bit hash of a vector. */ -template<unsigned int N> -inline uint160 Hash160(const prevector<N, unsigned char>& vch) -{ - return Hash160(vch.begin(), vch.end()); -} - /** A writer stream (for serialization) that computes a 256-bit hash. */ class CHashWriter { @@ -129,13 +110,13 @@ public: int GetVersion() const { return nVersion; } void write(const char *pch, size_t size) { - ctx.Write((const unsigned char*)pch, size); + ctx.Write({(const unsigned char*)pch, size}); } // invalidates the object uint256 GetHash() { uint256 result; - ctx.Finalize((unsigned char*)&result); + ctx.Finalize(result); return result; } @@ -200,7 +181,7 @@ uint256 SerializeHash(const T& obj, int nType=SER_GETHASH, int nVersion=PROTOCOL return ss.GetHash(); } -unsigned int MurmurHash3(unsigned int nHashSeed, const std::vector<unsigned char>& vDataToHash); +unsigned int MurmurHash3(unsigned int nHashSeed, Span<const unsigned char> vDataToHash); void BIP32Hash(const ChainCode &chainCode, unsigned int nChild, unsigned char header, const unsigned char data[32], unsigned char output[64]); diff --git a/src/httpserver.cpp b/src/httpserver.cpp index 5e78fd1d71..1e5ea2de83 100644 --- a/src/httpserver.cpp +++ b/src/httpserver.cpp @@ -7,10 +7,10 @@ #include <chainparamsbase.h> #include <compat.h> #include <netbase.h> +#include <node/ui_interface.h> #include <rpc/protocol.h> // For HTTP status codes #include <shutdown.h> #include <sync.h> -#include <ui_interface.h> #include <util/strencodings.h> #include <util/system.h> #include <util/threadnames.h> diff --git a/src/index/base.cpp b/src/index/base.cpp index 74ea421e13..f587205a28 100644 --- a/src/index/base.cpp +++ b/src/index/base.cpp @@ -4,9 +4,9 @@ #include <chainparams.h> #include <index/base.h> +#include <node/ui_interface.h> #include <shutdown.h> #include <tinyformat.h> -#include <ui_interface.h> #include <util/system.h> #include <util/translation.h> #include <validation.h> @@ -17,15 +17,13 @@ constexpr char DB_BEST_BLOCK = 'B'; constexpr int64_t SYNC_LOG_INTERVAL = 30; // seconds constexpr int64_t SYNC_LOCATOR_WRITE_INTERVAL = 30; // seconds -template<typename... Args> +template <typename... Args> static void FatalError(const char* fmt, const Args&... args) { std::string strMessage = tfm::format(fmt, args...); - SetMiscWarning(strMessage); + SetMiscWarning(Untranslated(strMessage)); LogPrintf("*** %s\n", strMessage); - uiInterface.ThreadSafeMessageBox( - Untranslated("Error: A fatal internal error occurred, see debug.log for details"), - "", CClientUIInterface::MSG_ERROR); + AbortError(_("A fatal internal error occurred, see debug.log for details")); StartShutdown(); } diff --git a/src/index/txindex.cpp b/src/index/txindex.cpp index 4626395ef0..64472714cc 100644 --- a/src/index/txindex.cpp +++ b/src/index/txindex.cpp @@ -3,14 +3,12 @@ // file COPYING or http://www.opensource.org/licenses/mit-license.php. #include <index/txindex.h> +#include <node/ui_interface.h> #include <shutdown.h> -#include <ui_interface.h> #include <util/system.h> #include <util/translation.h> #include <validation.h> -#include <boost/thread.hpp> - constexpr char DB_BEST_BLOCK = 'B'; constexpr char DB_TXINDEX = 't'; constexpr char DB_TXINDEX_BLOCK = 'T'; @@ -150,7 +148,6 @@ bool TxIndex::DB::MigrateData(CBlockTreeDB& block_tree_db, const CBlockLocator& bool interrupted = false; std::unique_ptr<CDBIterator> cursor(block_tree_db.NewIterator()); for (cursor->Seek(begin_key); cursor->Valid(); cursor->Next()) { - boost::this_thread::interruption_point(); if (ShutdownRequested()) { interrupted = true; break; diff --git a/src/init.cpp b/src/init.cpp index 24ddaac066..f5aef08211 100644 --- a/src/init.cpp +++ b/src/init.cpp @@ -18,11 +18,13 @@ #include <compat/sanity.h> #include <consensus/validation.h> #include <fs.h> +#include <hash.h> #include <httprpc.h> #include <httpserver.h> #include <index/blockfilterindex.h> #include <index/txindex.h> #include <interfaces/chain.h> +#include <interfaces/node.h> #include <key.h> #include <miner.h> #include <net.h> @@ -30,6 +32,7 @@ #include <net_processing.h> #include <netbase.h> #include <node/context.h> +#include <node/ui_interface.h> #include <policy/feerate.h> #include <policy/fees.h> #include <policy/policy.h> @@ -42,19 +45,19 @@ #include <script/sigcache.h> #include <script/standard.h> #include <shutdown.h> +#include <sync.h> #include <timedata.h> #include <torcontrol.h> #include <txdb.h> #include <txmempool.h> -#include <ui_interface.h> #include <util/asmap.h> +#include <util/check.h> #include <util/moneystr.h> +#include <util/string.h> #include <util/system.h> #include <util/threadnames.h> #include <util/translation.h> #include <validation.h> -#include <hash.h> - #include <validationinterface.h> #include <walletinitinterface.h> @@ -71,11 +74,9 @@ #include <sys/stat.h> #endif -#include <boost/algorithm/string/classification.hpp> #include <boost/algorithm/string/replace.hpp> -#include <boost/algorithm/string/split.hpp> #include <boost/signals2/signal.hpp> -#include <boost/thread.hpp> +#include <boost/thread/thread.hpp> #if ENABLE_ZMQ #include <zmq/zmqabstractnotifier.h> @@ -153,6 +154,8 @@ NODISCARD static bool CreatePidFile() static std::unique_ptr<ECCVerifyHandle> globalVerifyHandle; +static std::thread g_load_block; + static boost::thread_group threadGroup; void Interrupt(NodeContext& node) @@ -173,11 +176,10 @@ void Interrupt(NodeContext& node) void Shutdown(NodeContext& node) { + static Mutex g_shutdown_mutex; + TRY_LOCK(g_shutdown_mutex, lock_shutdown); + if (!lock_shutdown) return; LogPrintf("%s: In progress...\n", __func__); - static RecursiveMutex cs_Shutdown; - TRY_LOCK(cs_Shutdown, lockShutdown); - if (!lockShutdown) - return; /// Note: Shutdown() must be able to handle cases in which initialization failed part of the way, /// for example if the data directory was found to be locked. @@ -216,8 +218,9 @@ void Shutdown(NodeContext& node) StopTorControl(); // After everything has been shut down, but before things get flushed, stop the - // CScheduler/checkqueue threadGroup + // CScheduler/checkqueue, threadGroup and load block thread. if (node.scheduler) node.scheduler->stop(); + if (g_load_block.joinable()) g_load_block.join(); threadGroup.interrupt_all(); threadGroup.join_all(); @@ -367,9 +370,10 @@ void SetupServerArgs(NodeContext& node) { assert(!node.args); node.args = &gArgs; + ArgsManager& argsman = *node.args; SetupHelpOptions(gArgs); - gArgs.AddArg("-help-debug", "Print help message with debugging options and exit", ArgsManager::ALLOW_ANY, OptionsCategory::DEBUG_TEST); // server-only for now + argsman.AddArg("-help-debug", "Print help message with debugging options and exit", ArgsManager::ALLOW_ANY, OptionsCategory::DEBUG_TEST); // server-only for now const auto defaultBaseParams = CreateBaseChainParams(CBaseChainParams::MAIN); const auto testnetBaseParams = CreateBaseChainParams(CBaseChainParams::TESTNET); @@ -384,112 +388,109 @@ void SetupServerArgs(NodeContext& node) // GUI args. These will be overwritten by SetupUIArgs for the GUI "-choosedatadir", "-lang=<lang>", "-min", "-resetguisettings", "-splash", "-uiplatform"}; - gArgs.AddArg("-version", "Print version and exit", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-version", "Print version and exit", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); #if HAVE_SYSTEM - gArgs.AddArg("-alertnotify=<cmd>", "Execute command when a relevant alert is received or we see a really long fork (%s in cmd is replaced by message)", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-alertnotify=<cmd>", "Execute command when a relevant alert is received or we see a really long fork (%s in cmd is replaced by message)", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); #endif - gArgs.AddArg("-assumevalid=<hex>", strprintf("If this block is in the chain assume that it and its ancestors are valid and potentially skip their script verification (0 to verify all, default: %s, testnet: %s)", defaultChainParams->GetConsensus().defaultAssumeValid.GetHex(), testnetChainParams->GetConsensus().defaultAssumeValid.GetHex()), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-blocksdir=<dir>", "Specify directory to hold blocks subdirectory for *.dat files (default: <datadir>)", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-assumevalid=<hex>", strprintf("If this block is in the chain assume that it and its ancestors are valid and potentially skip their script verification (0 to verify all, default: %s, testnet: %s)", defaultChainParams->GetConsensus().defaultAssumeValid.GetHex(), testnetChainParams->GetConsensus().defaultAssumeValid.GetHex()), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-blocksdir=<dir>", "Specify directory to hold blocks subdirectory for *.dat files (default: <datadir>)", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); #if HAVE_SYSTEM - gArgs.AddArg("-blocknotify=<cmd>", "Execute command when the best block changes (%s in cmd is replaced by block hash)", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-blocknotify=<cmd>", "Execute command when the best block changes (%s in cmd is replaced by block hash)", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); #endif - gArgs.AddArg("-blockreconstructionextratxn=<n>", strprintf("Extra transactions to keep in memory for compact block reconstructions (default: %u)", DEFAULT_BLOCK_RECONSTRUCTION_EXTRA_TXN), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-blocksonly", strprintf("Whether to reject transactions from network peers. Automatic broadcast and rebroadcast of any transactions from inbound peers is disabled, unless '-whitelistforcerelay' is '1', in which case whitelisted peers' transactions will be relayed. RPC transactions are not affected. (default: %u)", DEFAULT_BLOCKSONLY), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-conf=<file>", strprintf("Specify configuration file. Relative paths will be prefixed by datadir location. (default: %s)", BITCOIN_CONF_FILENAME), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-datadir=<dir>", "Specify data directory", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-dbbatchsize", strprintf("Maximum database write batch size in bytes (default: %u)", nDefaultDbBatchSize), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::OPTIONS); - gArgs.AddArg("-dbcache=<n>", strprintf("Maximum database cache size <n> MiB (%d to %d, default: %d). In addition, unused mempool memory is shared for this cache (see -maxmempool).", nMinDbCache, nMaxDbCache, nDefaultDbCache), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-debuglogfile=<file>", strprintf("Specify location of debug log file. Relative paths will be prefixed by a net-specific datadir location. (-nodebuglogfile to disable; default: %s)", DEFAULT_DEBUGLOGFILE), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-feefilter", strprintf("Tell other nodes to filter invs to us by our mempool min fee (default: %u)", DEFAULT_FEEFILTER), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::OPTIONS); - gArgs.AddArg("-includeconf=<file>", "Specify additional configuration file, relative to the -datadir path (only useable from configuration file, not command line)", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-loadblock=<file>", "Imports blocks from external file on startup", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-maxmempool=<n>", strprintf("Keep the transaction memory pool below <n> megabytes (default: %u)", DEFAULT_MAX_MEMPOOL_SIZE), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-maxorphantx=<n>", strprintf("Keep at most <n> unconnectable transactions in memory (default: %u)", DEFAULT_MAX_ORPHAN_TRANSACTIONS), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-mempoolexpiry=<n>", strprintf("Do not keep transactions in the mempool longer than <n> hours (default: %u)", DEFAULT_MEMPOOL_EXPIRY), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-minimumchainwork=<hex>", strprintf("Minimum work assumed to exist on a valid chain in hex (default: %s, testnet: %s)", defaultChainParams->GetConsensus().nMinimumChainWork.GetHex(), testnetChainParams->GetConsensus().nMinimumChainWork.GetHex()), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::OPTIONS); - gArgs.AddArg("-par=<n>", strprintf("Set the number of script verification threads (%u to %d, 0 = auto, <0 = leave that many cores free, default: %d)", + argsman.AddArg("-blockreconstructionextratxn=<n>", strprintf("Extra transactions to keep in memory for compact block reconstructions (default: %u)", DEFAULT_BLOCK_RECONSTRUCTION_EXTRA_TXN), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-blocksonly", strprintf("Whether to reject transactions from network peers. Automatic broadcast and rebroadcast of any transactions from inbound peers is disabled, unless the peer has the 'forcerelay' permission. RPC transactions are not affected. (default: %u)", DEFAULT_BLOCKSONLY), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-conf=<file>", strprintf("Specify path to read-only configuration file. Relative paths will be prefixed by datadir location. (default: %s)", BITCOIN_CONF_FILENAME), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-datadir=<dir>", "Specify data directory", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-dbbatchsize", strprintf("Maximum database write batch size in bytes (default: %u)", nDefaultDbBatchSize), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::OPTIONS); + argsman.AddArg("-dbcache=<n>", strprintf("Maximum database cache size <n> MiB (%d to %d, default: %d). In addition, unused mempool memory is shared for this cache (see -maxmempool).", nMinDbCache, nMaxDbCache, nDefaultDbCache), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-debuglogfile=<file>", strprintf("Specify location of debug log file. Relative paths will be prefixed by a net-specific datadir location. (-nodebuglogfile to disable; default: %s)", DEFAULT_DEBUGLOGFILE), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-feefilter", strprintf("Tell other nodes to filter invs to us by our mempool min fee (default: %u)", DEFAULT_FEEFILTER), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::OPTIONS); + argsman.AddArg("-includeconf=<file>", "Specify additional configuration file, relative to the -datadir path (only useable from configuration file, not command line)", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-loadblock=<file>", "Imports blocks from external file on startup", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-maxmempool=<n>", strprintf("Keep the transaction memory pool below <n> megabytes (default: %u)", DEFAULT_MAX_MEMPOOL_SIZE), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-maxorphantx=<n>", strprintf("Keep at most <n> unconnectable transactions in memory (default: %u)", DEFAULT_MAX_ORPHAN_TRANSACTIONS), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-mempoolexpiry=<n>", strprintf("Do not keep transactions in the mempool longer than <n> hours (default: %u)", DEFAULT_MEMPOOL_EXPIRY), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-minimumchainwork=<hex>", strprintf("Minimum work assumed to exist on a valid chain in hex (default: %s, testnet: %s)", defaultChainParams->GetConsensus().nMinimumChainWork.GetHex(), testnetChainParams->GetConsensus().nMinimumChainWork.GetHex()), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::OPTIONS); + argsman.AddArg("-par=<n>", strprintf("Set the number of script verification threads (%u to %d, 0 = auto, <0 = leave that many cores free, default: %d)", -GetNumCores(), MAX_SCRIPTCHECK_THREADS, DEFAULT_SCRIPTCHECK_THREADS), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-persistmempool", strprintf("Whether to save the mempool on shutdown and load on restart (default: %u)", DEFAULT_PERSIST_MEMPOOL), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-pid=<file>", strprintf("Specify pid file. Relative paths will be prefixed by a net-specific datadir location. (default: %s)", BITCOIN_PID_FILENAME), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-prune=<n>", strprintf("Reduce storage requirements by enabling pruning (deleting) of old blocks. This allows the pruneblockchain RPC to be called to delete specific blocks, and enables automatic pruning of old blocks if a target size in MiB is provided. This mode is incompatible with -txindex and -rescan. " + argsman.AddArg("-persistmempool", strprintf("Whether to save the mempool on shutdown and load on restart (default: %u)", DEFAULT_PERSIST_MEMPOOL), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-pid=<file>", strprintf("Specify pid file. Relative paths will be prefixed by a net-specific datadir location. (default: %s)", BITCOIN_PID_FILENAME), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-prune=<n>", strprintf("Reduce storage requirements by enabling pruning (deleting) of old blocks. This allows the pruneblockchain RPC to be called to delete specific blocks, and enables automatic pruning of old blocks if a target size in MiB is provided. This mode is incompatible with -txindex and -rescan. " "Warning: Reverting this setting requires re-downloading the entire blockchain. " "(default: 0 = disable pruning blocks, 1 = allow manual pruning via RPC, >=%u = automatically prune block files to stay under the specified target size in MiB)", MIN_DISK_SPACE_FOR_BLOCK_FILES / 1024 / 1024), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-reindex", "Rebuild chain state and block index from the blk*.dat files on disk", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-reindex-chainstate", "Rebuild chain state from the currently indexed blocks. When in pruning mode or if blocks on disk might be corrupted, use full -reindex instead.", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-reindex", "Rebuild chain state and block index from the blk*.dat files on disk", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-reindex-chainstate", "Rebuild chain state from the currently indexed blocks. When in pruning mode or if blocks on disk might be corrupted, use full -reindex instead.", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-settings=<file>", strprintf("Specify path to dynamic settings data file. Can be disabled with -nosettings. File is written at runtime and not meant to be edited by users (use %s instead for custom settings). Relative paths will be prefixed by datadir location. (default: %s)", BITCOIN_CONF_FILENAME, BITCOIN_SETTINGS_FILENAME), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); #ifndef WIN32 - gArgs.AddArg("-sysperms", "Create new files with system default permissions, instead of umask 077 (only effective with disabled wallet functionality)", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-sysperms", "Create new files with system default permissions, instead of umask 077 (only effective with disabled wallet functionality)", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); #else hidden_args.emplace_back("-sysperms"); #endif - gArgs.AddArg("-txindex", strprintf("Maintain a full transaction index, used by the getrawtransaction rpc call (default: %u)", DEFAULT_TXINDEX), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-blockfilterindex=<type>", + argsman.AddArg("-txindex", strprintf("Maintain a full transaction index, used by the getrawtransaction rpc call (default: %u)", DEFAULT_TXINDEX), ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-blockfilterindex=<type>", strprintf("Maintain an index of compact filters by block (default: %s, values: %s).", DEFAULT_BLOCKFILTERINDEX, ListBlockFilterTypes()) + " If <type> is not supplied or if <type> = 1, indexes for all known types are enabled.", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); - gArgs.AddArg("-addnode=<ip>", "Add a node to connect to and attempt to keep the connection open (see the `addnode` RPC command help for more info). This option can be specified multiple times to add multiple nodes.", ArgsManager::ALLOW_ANY | ArgsManager::NETWORK_ONLY, OptionsCategory::CONNECTION); - gArgs.AddArg("-asmap=<file>", strprintf("Specify asn mapping used for bucketing of the peers (default: %s). Relative paths will be prefixed by the net-specific datadir location.", DEFAULT_ASMAP_FILENAME), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - gArgs.AddArg("-banscore=<n>", strprintf("Threshold for disconnecting misbehaving peers (default: %u)", DEFAULT_BANSCORE_THRESHOLD), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - gArgs.AddArg("-bantime=<n>", strprintf("Number of seconds to keep misbehaving peers from reconnecting (default: %u)", DEFAULT_MISBEHAVING_BANTIME), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - gArgs.AddArg("-bind=<addr>", "Bind to given address and always listen on it. Use [host]:port notation for IPv6", ArgsManager::ALLOW_ANY | ArgsManager::NETWORK_ONLY, OptionsCategory::CONNECTION); - gArgs.AddArg("-connect=<ip>", "Connect only to the specified node; -noconnect disables automatic connections (the rules for this peer are the same as for -addnode). This option can be specified multiple times to connect to multiple nodes.", ArgsManager::ALLOW_ANY | ArgsManager::NETWORK_ONLY, OptionsCategory::CONNECTION); - gArgs.AddArg("-discover", "Discover own IP addresses (default: 1 when listening and no -externalip or -proxy)", ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - gArgs.AddArg("-dns", strprintf("Allow DNS lookups for -addnode, -seednode and -connect (default: %u)", DEFAULT_NAME_LOOKUP), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - gArgs.AddArg("-dnsseed", "Query for peer addresses via DNS lookup, if low on addresses (default: 1 unless -connect used)", ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - gArgs.AddArg("-externalip=<ip>", "Specify your own public address", ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - gArgs.AddArg("-forcednsseed", strprintf("Always query for peer addresses via DNS lookup (default: %u)", DEFAULT_FORCEDNSSEED), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - gArgs.AddArg("-listen", "Accept connections from outside (default: 1 if no -proxy or -connect)", ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - gArgs.AddArg("-listenonion", strprintf("Automatically create Tor hidden service (default: %d)", DEFAULT_LISTEN_ONION), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - gArgs.AddArg("-maxconnections=<n>", strprintf("Maintain at most <n> connections to peers (default: %u)", DEFAULT_MAX_PEER_CONNECTIONS), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - gArgs.AddArg("-maxreceivebuffer=<n>", strprintf("Maximum per-connection receive buffer, <n>*1000 bytes (default: %u)", DEFAULT_MAXRECEIVEBUFFER), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - gArgs.AddArg("-maxsendbuffer=<n>", strprintf("Maximum per-connection send buffer, <n>*1000 bytes (default: %u)", DEFAULT_MAXSENDBUFFER), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - gArgs.AddArg("-maxtimeadjustment", strprintf("Maximum allowed median peer time offset adjustment. Local perspective of time may be influenced by peers forward or backward by this amount. (default: %u seconds)", DEFAULT_MAX_TIME_ADJUSTMENT), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - gArgs.AddArg("-maxuploadtarget=<n>", strprintf("Tries to keep outbound traffic under the given target (in MiB per 24h), 0 = no limit (default: %d)", DEFAULT_MAX_UPLOAD_TARGET), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - gArgs.AddArg("-onion=<ip:port>", "Use separate SOCKS5 proxy to reach peers via Tor hidden services, set -noonion to disable (default: -proxy)", ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - gArgs.AddArg("-onlynet=<net>", "Make outgoing connections only through network <net> (ipv4, ipv6 or onion). Incoming connections are not affected by this option. This option can be specified multiple times to allow multiple networks.", ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - gArgs.AddArg("-peerbloomfilters", strprintf("Support filtering of blocks and transaction with bloom filters (default: %u)", DEFAULT_PEERBLOOMFILTERS), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - gArgs.AddArg("-peerblockfilters", strprintf("Serve compact block filters to peers per BIP 157 (default: %u)", DEFAULT_PEERBLOCKFILTERS), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - gArgs.AddArg("-permitbaremultisig", strprintf("Relay non-P2SH multisig (default: %u)", DEFAULT_PERMIT_BAREMULTISIG), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - gArgs.AddArg("-port=<port>", strprintf("Listen for connections on <port> (default: %u, testnet: %u, regtest: %u)", defaultChainParams->GetDefaultPort(), testnetChainParams->GetDefaultPort(), regtestChainParams->GetDefaultPort()), ArgsManager::ALLOW_ANY | ArgsManager::NETWORK_ONLY, OptionsCategory::CONNECTION); - gArgs.AddArg("-proxy=<ip:port>", "Connect through SOCKS5 proxy, set -noproxy to disable (default: disabled)", ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - gArgs.AddArg("-proxyrandomize", strprintf("Randomize credentials for every proxy connection. This enables Tor stream isolation (default: %u)", DEFAULT_PROXYRANDOMIZE), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - gArgs.AddArg("-seednode=<ip>", "Connect to a node to retrieve peer addresses, and disconnect. This option can be specified multiple times to connect to multiple nodes.", ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - gArgs.AddArg("-timeout=<n>", strprintf("Specify connection timeout in milliseconds (minimum: 1, default: %d)", DEFAULT_CONNECT_TIMEOUT), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - gArgs.AddArg("-peertimeout=<n>", strprintf("Specify p2p connection timeout in seconds. This option determines the amount of time a peer may be inactive before the connection to it is dropped. (minimum: 1, default: %d)", DEFAULT_PEER_CONNECT_TIMEOUT), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::CONNECTION); - gArgs.AddArg("-torcontrol=<ip>:<port>", strprintf("Tor control port to use if onion listening enabled (default: %s)", DEFAULT_TOR_CONTROL), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - gArgs.AddArg("-torpassword=<pass>", "Tor control port password (default: empty)", ArgsManager::ALLOW_ANY | ArgsManager::SENSITIVE, OptionsCategory::CONNECTION); + argsman.AddArg("-addnode=<ip>", "Add a node to connect to and attempt to keep the connection open (see the `addnode` RPC command help for more info). This option can be specified multiple times to add multiple nodes.", ArgsManager::ALLOW_ANY | ArgsManager::NETWORK_ONLY, OptionsCategory::CONNECTION); + argsman.AddArg("-asmap=<file>", strprintf("Specify asn mapping used for bucketing of the peers (default: %s). Relative paths will be prefixed by the net-specific datadir location.", DEFAULT_ASMAP_FILENAME), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-bantime=<n>", strprintf("Default duration (in seconds) of manually configured bans (default: %u)", DEFAULT_MISBEHAVING_BANTIME), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-bind=<addr>", "Bind to given address and always listen on it. Use [host]:port notation for IPv6", ArgsManager::ALLOW_ANY | ArgsManager::NETWORK_ONLY, OptionsCategory::CONNECTION); + argsman.AddArg("-connect=<ip>", "Connect only to the specified node; -noconnect disables automatic connections (the rules for this peer are the same as for -addnode). This option can be specified multiple times to connect to multiple nodes.", ArgsManager::ALLOW_ANY | ArgsManager::NETWORK_ONLY, OptionsCategory::CONNECTION); + argsman.AddArg("-discover", "Discover own IP addresses (default: 1 when listening and no -externalip or -proxy)", ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-dns", strprintf("Allow DNS lookups for -addnode, -seednode and -connect (default: %u)", DEFAULT_NAME_LOOKUP), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-dnsseed", "Query for peer addresses via DNS lookup, if low on addresses (default: 1 unless -connect used)", ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-externalip=<ip>", "Specify your own public address", ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-forcednsseed", strprintf("Always query for peer addresses via DNS lookup (default: %u)", DEFAULT_FORCEDNSSEED), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-listen", "Accept connections from outside (default: 1 if no -proxy or -connect)", ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-listenonion", strprintf("Automatically create Tor onion service (default: %d)", DEFAULT_LISTEN_ONION), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-maxconnections=<n>", strprintf("Maintain at most <n> connections to peers (default: %u)", DEFAULT_MAX_PEER_CONNECTIONS), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-maxreceivebuffer=<n>", strprintf("Maximum per-connection receive buffer, <n>*1000 bytes (default: %u)", DEFAULT_MAXRECEIVEBUFFER), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-maxsendbuffer=<n>", strprintf("Maximum per-connection send buffer, <n>*1000 bytes (default: %u)", DEFAULT_MAXSENDBUFFER), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-maxtimeadjustment", strprintf("Maximum allowed median peer time offset adjustment. Local perspective of time may be influenced by peers forward or backward by this amount. (default: %u seconds)", DEFAULT_MAX_TIME_ADJUSTMENT), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-maxuploadtarget=<n>", strprintf("Tries to keep outbound traffic under the given target (in MiB per 24h). Limit does not apply to peers with 'download' permission. 0 = no limit (default: %d)", DEFAULT_MAX_UPLOAD_TARGET), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-onion=<ip:port>", "Use separate SOCKS5 proxy to reach peers via Tor onion services, set -noonion to disable (default: -proxy)", ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-onlynet=<net>", "Make outgoing connections only through network <net> (ipv4, ipv6 or onion). Incoming connections are not affected by this option. This option can be specified multiple times to allow multiple networks.", ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-peerbloomfilters", strprintf("Support filtering of blocks and transaction with bloom filters (default: %u)", DEFAULT_PEERBLOOMFILTERS), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-peerblockfilters", strprintf("Serve compact block filters to peers per BIP 157 (default: %u)", DEFAULT_PEERBLOCKFILTERS), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-permitbaremultisig", strprintf("Relay non-P2SH multisig (default: %u)", DEFAULT_PERMIT_BAREMULTISIG), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-port=<port>", strprintf("Listen for connections on <port> (default: %u, testnet: %u, regtest: %u)", defaultChainParams->GetDefaultPort(), testnetChainParams->GetDefaultPort(), regtestChainParams->GetDefaultPort()), ArgsManager::ALLOW_ANY | ArgsManager::NETWORK_ONLY, OptionsCategory::CONNECTION); + argsman.AddArg("-proxy=<ip:port>", "Connect through SOCKS5 proxy, set -noproxy to disable (default: disabled)", ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-proxyrandomize", strprintf("Randomize credentials for every proxy connection. This enables Tor stream isolation (default: %u)", DEFAULT_PROXYRANDOMIZE), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-seednode=<ip>", "Connect to a node to retrieve peer addresses, and disconnect. This option can be specified multiple times to connect to multiple nodes.", ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-networkactive", "Enable all P2P network activity (default: 1). Can be changed by the setnetworkactive RPC command", ArgsManager::ALLOW_BOOL, OptionsCategory::CONNECTION); + argsman.AddArg("-timeout=<n>", strprintf("Specify connection timeout in milliseconds (minimum: 1, default: %d)", DEFAULT_CONNECT_TIMEOUT), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-peertimeout=<n>", strprintf("Specify p2p connection timeout in seconds. This option determines the amount of time a peer may be inactive before the connection to it is dropped. (minimum: 1, default: %d)", DEFAULT_PEER_CONNECT_TIMEOUT), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::CONNECTION); + argsman.AddArg("-torcontrol=<ip>:<port>", strprintf("Tor control port to use if onion listening enabled (default: %s)", DEFAULT_TOR_CONTROL), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-torpassword=<pass>", "Tor control port password (default: empty)", ArgsManager::ALLOW_ANY | ArgsManager::SENSITIVE, OptionsCategory::CONNECTION); #ifdef USE_UPNP #if USE_UPNP - gArgs.AddArg("-upnp", "Use UPnP to map the listening port (default: 1 when listening and no -proxy)", ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-upnp", "Use UPnP to map the listening port (default: 1 when listening and no -proxy)", ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); #else - gArgs.AddArg("-upnp", strprintf("Use UPnP to map the listening port (default: %u)", 0), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + argsman.AddArg("-upnp", strprintf("Use UPnP to map the listening port (default: %u)", 0), ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); #endif #else hidden_args.emplace_back("-upnp"); #endif - gArgs.AddArg("-whitebind=<[permissions@]addr>", "Bind to given address and whitelist peers connecting to it. " - "Use [host]:port notation for IPv6. Allowed permissions are bloomfilter (allow requesting BIP37 filtered blocks and transactions), " - "noban (do not ban for misbehavior), " - "forcerelay (relay transactions that are already in the mempool; implies relay), " - "relay (relay even in -blocksonly mode), " - "and mempool (allow requesting BIP35 mempool contents). " - "Specify multiple permissions separated by commas (default: noban,mempool,relay). Can be specified multiple times.", ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - - gArgs.AddArg("-whitelist=<[permissions@]IP address or network>", "Whitelist peers connecting from the given IP address (e.g. 1.2.3.4) or " - "CIDR notated network(e.g. 1.2.3.0/24). Uses same permissions as " + argsman.AddArg("-whitebind=<[permissions@]addr>", "Bind to the given address and add permission flags to the peers connecting to it. " + "Use [host]:port notation for IPv6. Allowed permissions: " + Join(NET_PERMISSIONS_DOC, ", ") + ". " + "Specify multiple permissions separated by commas (default: download,noban,mempool,relay). Can be specified multiple times.", ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); + + argsman.AddArg("-whitelist=<[permissions@]IP address or network>", "Add permission flags to the peers connecting from the given IP address (e.g. 1.2.3.4) or " + "CIDR-notated network (e.g. 1.2.3.0/24). Uses the same permissions as " "-whitebind. Can be specified multiple times." , ArgsManager::ALLOW_ANY, OptionsCategory::CONNECTION); - g_wallet_init_interface.AddWalletOptions(); + g_wallet_init_interface.AddWalletOptions(argsman); #if ENABLE_ZMQ - gArgs.AddArg("-zmqpubhashblock=<address>", "Enable publish hash block in <address>", ArgsManager::ALLOW_ANY, OptionsCategory::ZMQ); - gArgs.AddArg("-zmqpubhashtx=<address>", "Enable publish hash transaction in <address>", ArgsManager::ALLOW_ANY, OptionsCategory::ZMQ); - gArgs.AddArg("-zmqpubrawblock=<address>", "Enable publish raw block in <address>", ArgsManager::ALLOW_ANY, OptionsCategory::ZMQ); - gArgs.AddArg("-zmqpubrawtx=<address>", "Enable publish raw transaction in <address>", ArgsManager::ALLOW_ANY, OptionsCategory::ZMQ); - gArgs.AddArg("-zmqpubhashblockhwm=<n>", strprintf("Set publish hash block outbound message high water mark (default: %d)", CZMQAbstractNotifier::DEFAULT_ZMQ_SNDHWM), ArgsManager::ALLOW_ANY, OptionsCategory::ZMQ); - gArgs.AddArg("-zmqpubhashtxhwm=<n>", strprintf("Set publish hash transaction outbound message high water mark (default: %d)", CZMQAbstractNotifier::DEFAULT_ZMQ_SNDHWM), ArgsManager::ALLOW_ANY, OptionsCategory::ZMQ); - gArgs.AddArg("-zmqpubrawblockhwm=<n>", strprintf("Set publish raw block outbound message high water mark (default: %d)", CZMQAbstractNotifier::DEFAULT_ZMQ_SNDHWM), ArgsManager::ALLOW_ANY, OptionsCategory::ZMQ); - gArgs.AddArg("-zmqpubrawtxhwm=<n>", strprintf("Set publish raw transaction outbound message high water mark (default: %d)", CZMQAbstractNotifier::DEFAULT_ZMQ_SNDHWM), ArgsManager::ALLOW_ANY, OptionsCategory::ZMQ); + argsman.AddArg("-zmqpubhashblock=<address>", "Enable publish hash block in <address>", ArgsManager::ALLOW_ANY, OptionsCategory::ZMQ); + argsman.AddArg("-zmqpubhashtx=<address>", "Enable publish hash transaction in <address>", ArgsManager::ALLOW_ANY, OptionsCategory::ZMQ); + argsman.AddArg("-zmqpubrawblock=<address>", "Enable publish raw block in <address>", ArgsManager::ALLOW_ANY, OptionsCategory::ZMQ); + argsman.AddArg("-zmqpubrawtx=<address>", "Enable publish raw transaction in <address>", ArgsManager::ALLOW_ANY, OptionsCategory::ZMQ); + argsman.AddArg("-zmqpubhashblockhwm=<n>", strprintf("Set publish hash block outbound message high water mark (default: %d)", CZMQAbstractNotifier::DEFAULT_ZMQ_SNDHWM), ArgsManager::ALLOW_ANY, OptionsCategory::ZMQ); + argsman.AddArg("-zmqpubhashtxhwm=<n>", strprintf("Set publish hash transaction outbound message high water mark (default: %d)", CZMQAbstractNotifier::DEFAULT_ZMQ_SNDHWM), ArgsManager::ALLOW_ANY, OptionsCategory::ZMQ); + argsman.AddArg("-zmqpubrawblockhwm=<n>", strprintf("Set publish raw block outbound message high water mark (default: %d)", CZMQAbstractNotifier::DEFAULT_ZMQ_SNDHWM), ArgsManager::ALLOW_ANY, OptionsCategory::ZMQ); + argsman.AddArg("-zmqpubrawtxhwm=<n>", strprintf("Set publish raw transaction outbound message high water mark (default: %d)", CZMQAbstractNotifier::DEFAULT_ZMQ_SNDHWM), ArgsManager::ALLOW_ANY, OptionsCategory::ZMQ); #else hidden_args.emplace_back("-zmqpubhashblock=<address>"); hidden_args.emplace_back("-zmqpubhashtx=<address>"); @@ -501,89 +502,82 @@ void SetupServerArgs(NodeContext& node) hidden_args.emplace_back("-zmqpubrawtxhwm=<n>"); #endif - gArgs.AddArg("-checkblocks=<n>", strprintf("How many blocks to check at startup (default: %u, 0 = all)", DEFAULT_CHECKBLOCKS), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-checklevel=<n>", strprintf("How thorough the block verification of -checkblocks is: " - "level 0 reads the blocks from disk, " - "level 1 verifies block validity, " - "level 2 verifies undo data, " - "level 3 checks disconnection of tip blocks, " - "and level 4 tries to reconnect the blocks, " - "each level includes the checks of the previous levels " - "(0-4, default: %u)", DEFAULT_CHECKLEVEL), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-checkblockindex", strprintf("Do a consistency check for the block tree, chainstate, and other validation data structures occasionally. (default: %u, regtest: %u)", defaultChainParams->DefaultConsistencyChecks(), regtestChainParams->DefaultConsistencyChecks()), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-checkmempool=<n>", strprintf("Run checks every <n> transactions (default: %u, regtest: %u)", defaultChainParams->DefaultConsistencyChecks(), regtestChainParams->DefaultConsistencyChecks()), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-checkpoints", strprintf("Enable rejection of any forks from the known historical chain until block 295000 (default: %u)", DEFAULT_CHECKPOINTS_ENABLED), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-deprecatedrpc=<method>", "Allows deprecated RPC method(s) to be used", ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-dropmessagestest=<n>", "Randomly drop 1 of every <n> network messages", ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-stopafterblockimport", strprintf("Stop running after importing blocks from disk (default: %u)", DEFAULT_STOPAFTERBLOCKIMPORT), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-stopatheight", strprintf("Stop running after reaching the given height in the main chain (default: %u)", DEFAULT_STOPATHEIGHT), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-limitancestorcount=<n>", strprintf("Do not accept transactions if number of in-mempool ancestors is <n> or more (default: %u)", DEFAULT_ANCESTOR_LIMIT), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-limitancestorsize=<n>", strprintf("Do not accept transactions whose size with all in-mempool ancestors exceeds <n> kilobytes (default: %u)", DEFAULT_ANCESTOR_SIZE_LIMIT), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-limitdescendantcount=<n>", strprintf("Do not accept transactions if any ancestor would have <n> or more in-mempool descendants (default: %u)", DEFAULT_DESCENDANT_LIMIT), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-limitdescendantsize=<n>", strprintf("Do not accept transactions if any ancestor would have more than <n> kilobytes of in-mempool descendants (default: %u).", DEFAULT_DESCENDANT_SIZE_LIMIT), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-addrmantest", "Allows to test address relay on localhost", ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-debug=<category>", "Output debugging information (default: -nodebug, supplying <category> is optional). " + argsman.AddArg("-checkblocks=<n>", strprintf("How many blocks to check at startup (default: %u, 0 = all)", DEFAULT_CHECKBLOCKS), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-checklevel=<n>", strprintf("How thorough the block verification of -checkblocks is: %s (0-4, default: %u)", Join(CHECKLEVEL_DOC, ", "), DEFAULT_CHECKLEVEL), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-checkblockindex", strprintf("Do a consistency check for the block tree, chainstate, and other validation data structures occasionally. (default: %u, regtest: %u)", defaultChainParams->DefaultConsistencyChecks(), regtestChainParams->DefaultConsistencyChecks()), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-checkmempool=<n>", strprintf("Run checks every <n> transactions (default: %u, regtest: %u)", defaultChainParams->DefaultConsistencyChecks(), regtestChainParams->DefaultConsistencyChecks()), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-checkpoints", strprintf("Enable rejection of any forks from the known historical chain until block 295000 (default: %u)", DEFAULT_CHECKPOINTS_ENABLED), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-deprecatedrpc=<method>", "Allows deprecated RPC method(s) to be used", ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-dropmessagestest=<n>", "Randomly drop 1 of every <n> network messages", ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-stopafterblockimport", strprintf("Stop running after importing blocks from disk (default: %u)", DEFAULT_STOPAFTERBLOCKIMPORT), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-stopatheight", strprintf("Stop running after reaching the given height in the main chain (default: %u)", DEFAULT_STOPATHEIGHT), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-limitancestorcount=<n>", strprintf("Do not accept transactions if number of in-mempool ancestors is <n> or more (default: %u)", DEFAULT_ANCESTOR_LIMIT), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-limitancestorsize=<n>", strprintf("Do not accept transactions whose size with all in-mempool ancestors exceeds <n> kilobytes (default: %u)", DEFAULT_ANCESTOR_SIZE_LIMIT), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-limitdescendantcount=<n>", strprintf("Do not accept transactions if any ancestor would have <n> or more in-mempool descendants (default: %u)", DEFAULT_DESCENDANT_LIMIT), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-limitdescendantsize=<n>", strprintf("Do not accept transactions if any ancestor would have more than <n> kilobytes of in-mempool descendants (default: %u).", DEFAULT_DESCENDANT_SIZE_LIMIT), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-addrmantest", "Allows to test address relay on localhost", ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-debug=<category>", "Output debugging information (default: -nodebug, supplying <category> is optional). " "If <category> is not supplied or if <category> = 1, output all debugging information. <category> can be: " + LogInstance().LogCategoriesString() + ".", ArgsManager::ALLOW_ANY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-debugexclude=<category>", strprintf("Exclude debugging information for a category. Can be used in conjunction with -debug=1 to output debug logs for all categories except one or more specified categories."), ArgsManager::ALLOW_ANY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-logips", strprintf("Include IP addresses in debug output (default: %u)", DEFAULT_LOGIPS), ArgsManager::ALLOW_ANY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-logtimestamps", strprintf("Prepend debug output with timestamp (default: %u)", DEFAULT_LOGTIMESTAMPS), ArgsManager::ALLOW_ANY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-debugexclude=<category>", strprintf("Exclude debugging information for a category. Can be used in conjunction with -debug=1 to output debug logs for all categories except one or more specified categories."), ArgsManager::ALLOW_ANY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-logips", strprintf("Include IP addresses in debug output (default: %u)", DEFAULT_LOGIPS), ArgsManager::ALLOW_ANY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-logtimestamps", strprintf("Prepend debug output with timestamp (default: %u)", DEFAULT_LOGTIMESTAMPS), ArgsManager::ALLOW_ANY, OptionsCategory::DEBUG_TEST); #ifdef HAVE_THREAD_LOCAL - gArgs.AddArg("-logthreadnames", strprintf("Prepend debug output with name of the originating thread (only available on platforms supporting thread_local) (default: %u)", DEFAULT_LOGTHREADNAMES), ArgsManager::ALLOW_ANY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-logthreadnames", strprintf("Prepend debug output with name of the originating thread (only available on platforms supporting thread_local) (default: %u)", DEFAULT_LOGTHREADNAMES), ArgsManager::ALLOW_ANY, OptionsCategory::DEBUG_TEST); #else hidden_args.emplace_back("-logthreadnames"); #endif - gArgs.AddArg("-logtimemicros", strprintf("Add microsecond precision to debug timestamps (default: %u)", DEFAULT_LOGTIMEMICROS), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-mocktime=<n>", "Replace actual time with " + UNIX_EPOCH_TIME + " (default: 0)", ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-maxsigcachesize=<n>", strprintf("Limit sum of signature cache and script execution cache sizes to <n> MiB (default: %u)", DEFAULT_MAX_SIG_CACHE_SIZE), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-maxtipage=<n>", strprintf("Maximum tip age in seconds to consider node in initial block download (default: %u)", DEFAULT_MAX_TIP_AGE), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-printpriority", strprintf("Log transaction fee per kB when mining blocks (default: %u)", DEFAULT_PRINTPRIORITY), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-printtoconsole", "Send trace/debug info to console (default: 1 when no -daemon. To disable logging to file, set -nodebuglogfile)", ArgsManager::ALLOW_ANY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-shrinkdebugfile", "Shrink debug.log file on client startup (default: 1 when no -debug)", ArgsManager::ALLOW_ANY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-uacomment=<cmt>", "Append comment to the user agent string", ArgsManager::ALLOW_ANY, OptionsCategory::DEBUG_TEST); - - SetupChainParamsBaseOptions(); - - gArgs.AddArg("-acceptnonstdtxn", strprintf("Relay and mine \"non-standard\" transactions (%sdefault: %u)", "testnet/regtest only; ", !testnetChainParams->RequireStandard()), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::NODE_RELAY); - gArgs.AddArg("-incrementalrelayfee=<amt>", strprintf("Fee rate (in %s/kB) used to define cost of relay, used for mempool limiting and BIP 125 replacement. (default: %s)", CURRENCY_UNIT, FormatMoney(DEFAULT_INCREMENTAL_RELAY_FEE)), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::NODE_RELAY); - gArgs.AddArg("-dustrelayfee=<amt>", strprintf("Fee rate (in %s/kB) used to define dust, the value of an output such that it will cost more than its value in fees at this fee rate to spend it. (default: %s)", CURRENCY_UNIT, FormatMoney(DUST_RELAY_TX_FEE)), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::NODE_RELAY); - gArgs.AddArg("-bytespersigop", strprintf("Equivalent bytes per sigop in transactions for relay and mining (default: %u)", DEFAULT_BYTES_PER_SIGOP), ArgsManager::ALLOW_ANY, OptionsCategory::NODE_RELAY); - gArgs.AddArg("-datacarrier", strprintf("Relay and mine data carrier transactions (default: %u)", DEFAULT_ACCEPT_DATACARRIER), ArgsManager::ALLOW_ANY, OptionsCategory::NODE_RELAY); - gArgs.AddArg("-datacarriersize", strprintf("Maximum size of data in data carrier transactions we relay and mine (default: %u)", MAX_OP_RETURN_RELAY), ArgsManager::ALLOW_ANY, OptionsCategory::NODE_RELAY); - gArgs.AddArg("-minrelaytxfee=<amt>", strprintf("Fees (in %s/kB) smaller than this are considered zero fee for relaying, mining and transaction creation (default: %s)", + argsman.AddArg("-logtimemicros", strprintf("Add microsecond precision to debug timestamps (default: %u)", DEFAULT_LOGTIMEMICROS), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-mocktime=<n>", "Replace actual time with " + UNIX_EPOCH_TIME + " (default: 0)", ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-maxsigcachesize=<n>", strprintf("Limit sum of signature cache and script execution cache sizes to <n> MiB (default: %u)", DEFAULT_MAX_SIG_CACHE_SIZE), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-maxtipage=<n>", strprintf("Maximum tip age in seconds to consider node in initial block download (default: %u)", DEFAULT_MAX_TIP_AGE), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-printpriority", strprintf("Log transaction fee per kB when mining blocks (default: %u)", DEFAULT_PRINTPRIORITY), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-printtoconsole", "Send trace/debug info to console (default: 1 when no -daemon. To disable logging to file, set -nodebuglogfile)", ArgsManager::ALLOW_ANY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-shrinkdebugfile", "Shrink debug.log file on client startup (default: 1 when no -debug)", ArgsManager::ALLOW_ANY, OptionsCategory::DEBUG_TEST); + argsman.AddArg("-uacomment=<cmt>", "Append comment to the user agent string", ArgsManager::ALLOW_ANY, OptionsCategory::DEBUG_TEST); + + SetupChainParamsBaseOptions(argsman); + + argsman.AddArg("-acceptnonstdtxn", strprintf("Relay and mine \"non-standard\" transactions (%sdefault: %u)", "testnet/regtest only; ", !testnetChainParams->RequireStandard()), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::NODE_RELAY); + argsman.AddArg("-incrementalrelayfee=<amt>", strprintf("Fee rate (in %s/kB) used to define cost of relay, used for mempool limiting and BIP 125 replacement. (default: %s)", CURRENCY_UNIT, FormatMoney(DEFAULT_INCREMENTAL_RELAY_FEE)), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::NODE_RELAY); + argsman.AddArg("-dustrelayfee=<amt>", strprintf("Fee rate (in %s/kB) used to define dust, the value of an output such that it will cost more than its value in fees at this fee rate to spend it. (default: %s)", CURRENCY_UNIT, FormatMoney(DUST_RELAY_TX_FEE)), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::NODE_RELAY); + argsman.AddArg("-bytespersigop", strprintf("Equivalent bytes per sigop in transactions for relay and mining (default: %u)", DEFAULT_BYTES_PER_SIGOP), ArgsManager::ALLOW_ANY, OptionsCategory::NODE_RELAY); + argsman.AddArg("-datacarrier", strprintf("Relay and mine data carrier transactions (default: %u)", DEFAULT_ACCEPT_DATACARRIER), ArgsManager::ALLOW_ANY, OptionsCategory::NODE_RELAY); + argsman.AddArg("-datacarriersize", strprintf("Maximum size of data in data carrier transactions we relay and mine (default: %u)", MAX_OP_RETURN_RELAY), ArgsManager::ALLOW_ANY, OptionsCategory::NODE_RELAY); + argsman.AddArg("-minrelaytxfee=<amt>", strprintf("Fees (in %s/kB) smaller than this are considered zero fee for relaying, mining and transaction creation (default: %s)", CURRENCY_UNIT, FormatMoney(DEFAULT_MIN_RELAY_TX_FEE)), ArgsManager::ALLOW_ANY, OptionsCategory::NODE_RELAY); - gArgs.AddArg("-whitelistforcerelay", strprintf("Add 'forcerelay' permission to whitelisted inbound peers with default permissions. This will relay transactions even if the transactions were already in the mempool. (default: %d)", DEFAULT_WHITELISTFORCERELAY), ArgsManager::ALLOW_ANY, OptionsCategory::NODE_RELAY); - gArgs.AddArg("-whitelistrelay", strprintf("Add 'relay' permission to whitelisted inbound peers with default permissions. This will accept relayed transactions even when not relaying transactions (default: %d)", DEFAULT_WHITELISTRELAY), ArgsManager::ALLOW_ANY, OptionsCategory::NODE_RELAY); - - - gArgs.AddArg("-blockmaxweight=<n>", strprintf("Set maximum BIP141 block weight (default: %d)", DEFAULT_BLOCK_MAX_WEIGHT), ArgsManager::ALLOW_ANY, OptionsCategory::BLOCK_CREATION); - gArgs.AddArg("-blockmintxfee=<amt>", strprintf("Set lowest fee rate (in %s/kB) for transactions to be included in block creation. (default: %s)", CURRENCY_UNIT, FormatMoney(DEFAULT_BLOCK_MIN_TX_FEE)), ArgsManager::ALLOW_ANY, OptionsCategory::BLOCK_CREATION); - gArgs.AddArg("-blockversion=<n>", "Override block version to test forking scenarios", ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::BLOCK_CREATION); - - gArgs.AddArg("-rest", strprintf("Accept public REST requests (default: %u)", DEFAULT_REST_ENABLE), ArgsManager::ALLOW_ANY, OptionsCategory::RPC); - gArgs.AddArg("-rpcallowip=<ip>", "Allow JSON-RPC connections from specified source. Valid for <ip> are a single IP (e.g. 1.2.3.4), a network/netmask (e.g. 1.2.3.4/255.255.255.0) or a network/CIDR (e.g. 1.2.3.4/24). This option can be specified multiple times", ArgsManager::ALLOW_ANY, OptionsCategory::RPC); - gArgs.AddArg("-rpcauth=<userpw>", "Username and HMAC-SHA-256 hashed password for JSON-RPC connections. The field <userpw> comes in the format: <USERNAME>:<SALT>$<HASH>. A canonical python script is included in share/rpcauth. The client then connects normally using the rpcuser=<USERNAME>/rpcpassword=<PASSWORD> pair of arguments. This option can be specified multiple times", ArgsManager::ALLOW_ANY | ArgsManager::SENSITIVE, OptionsCategory::RPC); - gArgs.AddArg("-rpcbind=<addr>[:port]", "Bind to given address to listen for JSON-RPC connections. Do not expose the RPC server to untrusted networks such as the public internet! This option is ignored unless -rpcallowip is also passed. Port is optional and overrides -rpcport. Use [host]:port notation for IPv6. This option can be specified multiple times (default: 127.0.0.1 and ::1 i.e., localhost)", ArgsManager::ALLOW_ANY | ArgsManager::NETWORK_ONLY | ArgsManager::SENSITIVE, OptionsCategory::RPC); - gArgs.AddArg("-rpccookiefile=<loc>", "Location of the auth cookie. Relative paths will be prefixed by a net-specific datadir location. (default: data dir)", ArgsManager::ALLOW_ANY, OptionsCategory::RPC); - gArgs.AddArg("-rpcpassword=<pw>", "Password for JSON-RPC connections", ArgsManager::ALLOW_ANY | ArgsManager::SENSITIVE, OptionsCategory::RPC); - gArgs.AddArg("-rpcport=<port>", strprintf("Listen for JSON-RPC connections on <port> (default: %u, testnet: %u, regtest: %u)", defaultBaseParams->RPCPort(), testnetBaseParams->RPCPort(), regtestBaseParams->RPCPort()), ArgsManager::ALLOW_ANY | ArgsManager::NETWORK_ONLY, OptionsCategory::RPC); - gArgs.AddArg("-rpcserialversion", strprintf("Sets the serialization of raw transaction or block hex returned in non-verbose mode, non-segwit(0) or segwit(1) (default: %d)", DEFAULT_RPC_SERIALIZE_VERSION), ArgsManager::ALLOW_ANY, OptionsCategory::RPC); - gArgs.AddArg("-rpcservertimeout=<n>", strprintf("Timeout during HTTP requests (default: %d)", DEFAULT_HTTP_SERVER_TIMEOUT), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::RPC); - gArgs.AddArg("-rpcthreads=<n>", strprintf("Set the number of threads to service RPC calls (default: %d)", DEFAULT_HTTP_THREADS), ArgsManager::ALLOW_ANY, OptionsCategory::RPC); - gArgs.AddArg("-rpcuser=<user>", "Username for JSON-RPC connections", ArgsManager::ALLOW_ANY | ArgsManager::SENSITIVE, OptionsCategory::RPC); - gArgs.AddArg("-rpcwhitelist=<whitelist>", "Set a whitelist to filter incoming RPC calls for a specific user. The field <whitelist> comes in the format: <USERNAME>:<rpc 1>,<rpc 2>,...,<rpc n>. If multiple whitelists are set for a given user, they are set-intersected. See -rpcwhitelistdefault documentation for information on default whitelist behavior.", ArgsManager::ALLOW_ANY, OptionsCategory::RPC); - gArgs.AddArg("-rpcwhitelistdefault", "Sets default behavior for rpc whitelisting. Unless rpcwhitelistdefault is set to 0, if any -rpcwhitelist is set, the rpc server acts as if all rpc users are subject to empty-unless-otherwise-specified whitelists. If rpcwhitelistdefault is set to 1 and no -rpcwhitelist is set, rpc server acts as if all rpc users are subject to empty whitelists.", ArgsManager::ALLOW_BOOL, OptionsCategory::RPC); - gArgs.AddArg("-rpcworkqueue=<n>", strprintf("Set the depth of the work queue to service RPC calls (default: %d)", DEFAULT_HTTP_WORKQUEUE), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::RPC); - gArgs.AddArg("-server", "Accept command line and JSON-RPC commands", ArgsManager::ALLOW_ANY, OptionsCategory::RPC); + argsman.AddArg("-whitelistforcerelay", strprintf("Add 'forcerelay' permission to whitelisted inbound peers with default permissions. This will relay transactions even if the transactions were already in the mempool. (default: %d)", DEFAULT_WHITELISTFORCERELAY), ArgsManager::ALLOW_ANY, OptionsCategory::NODE_RELAY); + argsman.AddArg("-whitelistrelay", strprintf("Add 'relay' permission to whitelisted inbound peers with default permissions. This will accept relayed transactions even when not relaying transactions (default: %d)", DEFAULT_WHITELISTRELAY), ArgsManager::ALLOW_ANY, OptionsCategory::NODE_RELAY); + + + argsman.AddArg("-blockmaxweight=<n>", strprintf("Set maximum BIP141 block weight (default: %d)", DEFAULT_BLOCK_MAX_WEIGHT), ArgsManager::ALLOW_ANY, OptionsCategory::BLOCK_CREATION); + argsman.AddArg("-blockmintxfee=<amt>", strprintf("Set lowest fee rate (in %s/kB) for transactions to be included in block creation. (default: %s)", CURRENCY_UNIT, FormatMoney(DEFAULT_BLOCK_MIN_TX_FEE)), ArgsManager::ALLOW_ANY, OptionsCategory::BLOCK_CREATION); + argsman.AddArg("-blockversion=<n>", "Override block version to test forking scenarios", ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::BLOCK_CREATION); + + argsman.AddArg("-rest", strprintf("Accept public REST requests (default: %u)", DEFAULT_REST_ENABLE), ArgsManager::ALLOW_ANY, OptionsCategory::RPC); + argsman.AddArg("-rpcallowip=<ip>", "Allow JSON-RPC connections from specified source. Valid for <ip> are a single IP (e.g. 1.2.3.4), a network/netmask (e.g. 1.2.3.4/255.255.255.0) or a network/CIDR (e.g. 1.2.3.4/24). This option can be specified multiple times", ArgsManager::ALLOW_ANY, OptionsCategory::RPC); + argsman.AddArg("-rpcauth=<userpw>", "Username and HMAC-SHA-256 hashed password for JSON-RPC connections. The field <userpw> comes in the format: <USERNAME>:<SALT>$<HASH>. A canonical python script is included in share/rpcauth. The client then connects normally using the rpcuser=<USERNAME>/rpcpassword=<PASSWORD> pair of arguments. This option can be specified multiple times", ArgsManager::ALLOW_ANY | ArgsManager::SENSITIVE, OptionsCategory::RPC); + argsman.AddArg("-rpcbind=<addr>[:port]", "Bind to given address to listen for JSON-RPC connections. Do not expose the RPC server to untrusted networks such as the public internet! This option is ignored unless -rpcallowip is also passed. Port is optional and overrides -rpcport. Use [host]:port notation for IPv6. This option can be specified multiple times (default: 127.0.0.1 and ::1 i.e., localhost)", ArgsManager::ALLOW_ANY | ArgsManager::NETWORK_ONLY | ArgsManager::SENSITIVE, OptionsCategory::RPC); + argsman.AddArg("-rpccookiefile=<loc>", "Location of the auth cookie. Relative paths will be prefixed by a net-specific datadir location. (default: data dir)", ArgsManager::ALLOW_ANY, OptionsCategory::RPC); + argsman.AddArg("-rpcpassword=<pw>", "Password for JSON-RPC connections", ArgsManager::ALLOW_ANY | ArgsManager::SENSITIVE, OptionsCategory::RPC); + argsman.AddArg("-rpcport=<port>", strprintf("Listen for JSON-RPC connections on <port> (default: %u, testnet: %u, regtest: %u)", defaultBaseParams->RPCPort(), testnetBaseParams->RPCPort(), regtestBaseParams->RPCPort()), ArgsManager::ALLOW_ANY | ArgsManager::NETWORK_ONLY, OptionsCategory::RPC); + argsman.AddArg("-rpcserialversion", strprintf("Sets the serialization of raw transaction or block hex returned in non-verbose mode, non-segwit(0) or segwit(1) (default: %d)", DEFAULT_RPC_SERIALIZE_VERSION), ArgsManager::ALLOW_ANY, OptionsCategory::RPC); + argsman.AddArg("-rpcservertimeout=<n>", strprintf("Timeout during HTTP requests (default: %d)", DEFAULT_HTTP_SERVER_TIMEOUT), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::RPC); + argsman.AddArg("-rpcthreads=<n>", strprintf("Set the number of threads to service RPC calls (default: %d)", DEFAULT_HTTP_THREADS), ArgsManager::ALLOW_ANY, OptionsCategory::RPC); + argsman.AddArg("-rpcuser=<user>", "Username for JSON-RPC connections", ArgsManager::ALLOW_ANY | ArgsManager::SENSITIVE, OptionsCategory::RPC); + argsman.AddArg("-rpcwhitelist=<whitelist>", "Set a whitelist to filter incoming RPC calls for a specific user. The field <whitelist> comes in the format: <USERNAME>:<rpc 1>,<rpc 2>,...,<rpc n>. If multiple whitelists are set for a given user, they are set-intersected. See -rpcwhitelistdefault documentation for information on default whitelist behavior.", ArgsManager::ALLOW_ANY, OptionsCategory::RPC); + argsman.AddArg("-rpcwhitelistdefault", "Sets default behavior for rpc whitelisting. Unless rpcwhitelistdefault is set to 0, if any -rpcwhitelist is set, the rpc server acts as if all rpc users are subject to empty-unless-otherwise-specified whitelists. If rpcwhitelistdefault is set to 1 and no -rpcwhitelist is set, rpc server acts as if all rpc users are subject to empty whitelists.", ArgsManager::ALLOW_BOOL, OptionsCategory::RPC); + argsman.AddArg("-rpcworkqueue=<n>", strprintf("Set the depth of the work queue to service RPC calls (default: %d)", DEFAULT_HTTP_WORKQUEUE), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::RPC); + argsman.AddArg("-server", "Accept command line and JSON-RPC commands", ArgsManager::ALLOW_ANY, OptionsCategory::RPC); #if HAVE_DECL_DAEMON - gArgs.AddArg("-daemon", "Run in the background as a daemon and accept commands", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); + argsman.AddArg("-daemon", "Run in the background as a daemon and accept commands", ArgsManager::ALLOW_ANY, OptionsCategory::OPTIONS); #else hidden_args.emplace_back("-daemon"); #endif // Add the hidden options - gArgs.AddHiddenArgs(hidden_args); + argsman.AddHiddenArgs(hidden_args); } std::string LicenseInfo() @@ -693,7 +687,6 @@ static void CleanupBlockRevFiles() static void ThreadImport(ChainstateManager& chainman, std::vector<fs::path> vImportFiles) { const CChainParams& chainparams = Params(); - util::ThreadRename("loadblk"); ScheduleBatchPriority(); { @@ -785,13 +778,14 @@ static bool InitSanityCheck() return true; } -static bool AppInitServers(const util::Ref& context) +static bool AppInitServers(const util::Ref& context, NodeContext& node) { RPCServer::OnStarted(&OnRPCStarted); RPCServer::OnStopped(&OnRPCStopped); if (!InitHTTPServer()) return false; StartRPC(); + node.rpc_interruption_point = RpcInterruptionPoint; if (!StartHTTPRPC(context)) return false; if (gArgs.GetBoolArg("-rest", DEFAULT_REST_ENABLE)) StartREST(context); @@ -964,17 +958,27 @@ bool AppInitParameterInteraction() // also see: InitParameterInteraction() - // Warn if network-specific options (-addnode, -connect, etc) are + // Error if network-specific options (-addnode, -connect, etc) are // specified in default section of config file, but not overridden // on the command line or in this network's section of the config file. std::string network = gArgs.GetChainName(); + bilingual_str errors; for (const auto& arg : gArgs.GetUnsuitableSectionOnlyArgs()) { - return InitError(strprintf(_("Config setting for %s only applied on %s network when in [%s] section."), arg, network, network)); + errors += strprintf(_("Config setting for %s only applied on %s network when in [%s] section.") + Untranslated("\n"), arg, network, network); + } + + if (!errors.empty()) { + return InitError(errors); } // Warn if unrecognized section name are present in the config file. + bilingual_str warnings; for (const auto& section : gArgs.GetUnrecognizedSections()) { - InitWarning(strprintf(Untranslated("%s:%i ") + _("Section [%s] is not recognized."), section.m_file, section.m_line, section.m_name)); + warnings += strprintf(Untranslated("%s:%i ") + _("Section [%s] is not recognized.") + Untranslated("\n"), section.m_file, section.m_line, section.m_name); + } + + if (!warnings.empty()) { + InitWarning(warnings); } if (!fs::is_directory(GetBlocksDir())) { @@ -1241,7 +1245,7 @@ bool AppInitLockDataDirectory() return true; } -bool AppInitMain(const util::Ref& context, NodeContext& node) +bool AppInitMain(const util::Ref& context, NodeContext& node, interfaces::BlockAndHeaderTipInfo* tip_info) { const CChainParams& chainparams = Params(); // ********************************************************* Step 4a: application initialization @@ -1320,8 +1324,7 @@ bool AppInitMain(const util::Ref& context, NodeContext& node) node.scheduler = MakeUnique<CScheduler>(); // Start the lightweight task scheduler thread - CScheduler::Function serviceLoop = [&node]{ node.scheduler->serviceQueue(); }; - threadGroup.create_thread(std::bind(&TraceThread<CScheduler::Function>, "scheduler", serviceLoop)); + threadGroup.create_thread([&] { TraceThread("scheduler", [&] { node.scheduler->serviceQueue(); }); }); // Gather some entropy once per minute. node.scheduler->scheduleEvery([]{ @@ -1355,7 +1358,7 @@ bool AppInitMain(const util::Ref& context, NodeContext& node) if (gArgs.GetBoolArg("-server", false)) { uiInterface.InitMessage_connect(SetRPCWarmupStatus); - if (!AppInitServers(context)) + if (!AppInitServers(context, node)) return InitError(_("Unable to start HTTP server. See debug log for details.")); } @@ -1375,16 +1378,16 @@ bool AppInitMain(const util::Ref& context, NodeContext& node) assert(!node.banman); node.banman = MakeUnique<BanMan>(GetDataDir() / "banlist.dat", &uiInterface, gArgs.GetArg("-bantime", DEFAULT_MISBEHAVING_BANTIME)); assert(!node.connman); - node.connman = std::unique_ptr<CConnman>(new CConnman(GetRand(std::numeric_limits<uint64_t>::max()), GetRand(std::numeric_limits<uint64_t>::max()))); + node.connman = MakeUnique<CConnman>(GetRand(std::numeric_limits<uint64_t>::max()), GetRand(std::numeric_limits<uint64_t>::max()), gArgs.GetBoolArg("-networkactive", true)); // Make mempool generally available in the node context. For example the connection manager, wallet, or RPC threads, // which are all started after this, may use it from the node context. assert(!node.mempool); node.mempool = &::mempool; assert(!node.chainman); node.chainman = &g_chainman; - ChainstateManager& chainman = EnsureChainman(node); + ChainstateManager& chainman = *Assert(node.chainman); - node.peer_logic.reset(new PeerLogicValidation(node.connman.get(), node.banman.get(), *node.scheduler, *node.chainman, *node.mempool)); + node.peer_logic.reset(new PeerLogicValidation(node.connman.get(), node.banman.get(), *node.scheduler, chainman, *node.mempool)); RegisterValidationInterface(node.peer_logic.get()); // sanitize comments per BIP-0014, format user agent and check total size @@ -1470,7 +1473,7 @@ bool AppInitMain(const util::Ref& context, NodeContext& node) if (Lookup(strAddr, addrLocal, GetListenPort(), fNameLookup) && addrLocal.IsValid()) AddLocal(addrLocal, LOCAL_MANUAL); else - return InitError(Untranslated(ResolveErrMsg("externalip", strAddr))); + return InitError(ResolveErrMsg("externalip", strAddr)); } // Read asmap file if configured @@ -1535,7 +1538,7 @@ bool AppInitMain(const util::Ref& context, NodeContext& node) int64_t nCoinDBCache = std::min(nTotalCache / 2, (nTotalCache / 4) + (1 << 23)); // use 25%-50% of the remainder for disk cache nCoinDBCache = std::min(nCoinDBCache, nMaxCoinsDBCache << 20); // cap total coins db cache nTotalCache -= nCoinDBCache; - nCoinCacheUsage = nTotalCache; // the rest goes to in-memory cache + int64_t nCoinCacheUsage = nTotalCache; // the rest goes to in-memory cache int64_t nMempoolSizeMax = gArgs.GetArg("-maxmempool", DEFAULT_MAX_MEMPOOL_SIZE) * 1000000; LogPrintf("Cache configuration:\n"); LogPrintf("* Using %.1f MiB for block index database\n", nBlockTreeDBCache * (1.0 / 1024 / 1024)); @@ -1564,7 +1567,10 @@ bool AppInitMain(const util::Ref& context, NodeContext& node) try { LOCK(cs_main); chainman.InitializeChainstate(); - UnloadBlockIndex(); + chainman.m_total_coinstip_cache = nCoinCacheUsage; + chainman.m_total_coinsdb_cache = nCoinDBCache; + + UnloadBlockIndex(node.mempool); // new CBlockTreeDB tries to delete the existing file, which // fails if it's still open from the previous loop. Close it first: @@ -1592,7 +1598,7 @@ bool AppInitMain(const util::Ref& context, NodeContext& node) // If the loaded chain has a wrong genesis, bail out immediately // (we're likely using a testnet datadir, or the other way around). - if (!::BlockIndex().empty() && + if (!chainman.BlockIndex().empty() && !LookupBlockIndex(chainparams.GetConsensus().hashGenesisBlock)) { return InitError(_("Incorrect or no genesis block found. Wrong datadir for network?")); } @@ -1647,7 +1653,7 @@ bool AppInitMain(const util::Ref& context, NodeContext& node) } // The on-disk coinsdb is now in a good state, create the cache - chainstate->InitCoinsCache(); + chainstate->InitCoinsCache(nCoinCacheUsage); assert(chainstate->CanFlushToDisk()); if (!is_coinsview_empty(chainstate)) { @@ -1847,7 +1853,7 @@ bool AppInitMain(const util::Ref& context, NodeContext& node) vImportFiles.push_back(strFile); } - threadGroup.create_thread([=, &chainman] { ThreadImport(chainman, vImportFiles); }); + g_load_block = std::thread(&TraceThread<std::function<void()>>, "loadblk", [=, &chainman]{ ThreadImport(chainman, vImportFiles); }); // Wait for genesis block to be processed { @@ -1872,8 +1878,17 @@ bool AppInitMain(const util::Ref& context, NodeContext& node) //// debug print { LOCK(cs_main); - LogPrintf("block tree size = %u\n", ::BlockIndex().size()); - chain_active_height = ::ChainActive().Height(); + LogPrintf("block tree size = %u\n", chainman.BlockIndex().size()); + chain_active_height = chainman.ActiveChain().Height(); + if (tip_info) { + tip_info->block_height = chain_active_height; + tip_info->block_time = chainman.ActiveChain().Tip() ? chainman.ActiveChain().Tip()->GetBlockTime() : Params().GenesisBlock().GetBlockTime(); + tip_info->verification_progress = GuessVerificationProgress(Params().TxData(), chainman.ActiveChain().Tip()); + } + if (tip_info && ::pindexBestHeader) { + tip_info->header_height = ::pindexBestHeader->nHeight; + tip_info->header_time = ::pindexBestHeader->GetBlockTime(); + } } LogPrintf("nBestHeight = %d\n", chain_active_height); @@ -1891,7 +1906,7 @@ bool AppInitMain(const util::Ref& context, NodeContext& node) connOptions.nLocalServices = nLocalServices; connOptions.nMaxConnections = nMaxConnections; connOptions.m_max_outbound_full_relay = std::min(MAX_OUTBOUND_FULL_RELAY_CONNECTIONS, connOptions.nMaxConnections); - connOptions.m_max_outbound_block_relay = std::min(MAX_BLOCKS_ONLY_CONNECTIONS, connOptions.nMaxConnections-connOptions.m_max_outbound_full_relay); + connOptions.m_max_outbound_block_relay = std::min(MAX_BLOCK_RELAY_ONLY_CONNECTIONS, connOptions.nMaxConnections-connOptions.m_max_outbound_full_relay); connOptions.nMaxAddnode = MAX_ADDNODE_CONNECTIONS; connOptions.nMaxFeeler = MAX_FEELER_CONNECTIONS; connOptions.nBestHeight = chain_active_height; @@ -1909,21 +1924,21 @@ bool AppInitMain(const util::Ref& context, NodeContext& node) for (const std::string& strBind : gArgs.GetArgs("-bind")) { CService addrBind; if (!Lookup(strBind, addrBind, GetListenPort(), false)) { - return InitError(Untranslated(ResolveErrMsg("bind", strBind))); + return InitError(ResolveErrMsg("bind", strBind)); } connOptions.vBinds.push_back(addrBind); } for (const std::string& strBind : gArgs.GetArgs("-whitebind")) { NetWhitebindPermissions whitebind; - std::string error; - if (!NetWhitebindPermissions::TryParse(strBind, whitebind, error)) return InitError(Untranslated(error)); + bilingual_str error; + if (!NetWhitebindPermissions::TryParse(strBind, whitebind, error)) return InitError(error); connOptions.vWhiteBinds.push_back(whitebind); } for (const auto& net : gArgs.GetArgs("-whitelist")) { NetWhitelistPermissions subnet; - std::string error; - if (!NetWhitelistPermissions::TryParse(net, subnet, error)) return InitError(Untranslated(error)); + bilingual_str error; + if (!NetWhitelistPermissions::TryParse(net, subnet, error)) return InitError(error); connOptions.vWhitelistedRange.push_back(subnet); } diff --git a/src/init.h b/src/init.h index 33fe96e8ea..20008ba5be 100644 --- a/src/init.h +++ b/src/init.h @@ -11,6 +11,9 @@ #include <util/system.h> struct NodeContext; +namespace interfaces { +struct BlockAndHeaderTipInfo; +} namespace boost { class thread_group; } // namespace boost @@ -54,7 +57,7 @@ bool AppInitLockDataDirectory(); * @note This should only be done after daemonization. Call Shutdown() if this function fails. * @pre Parameters should be parsed and config file should be read, AppInitLockDataDirectory should have been called. */ -bool AppInitMain(const util::Ref& context, NodeContext& node); +bool AppInitMain(const util::Ref& context, NodeContext& node, interfaces::BlockAndHeaderTipInfo* tip_info = nullptr); /** * Register all arguments with the ArgsManager diff --git a/src/interfaces/chain.cpp b/src/interfaces/chain.cpp index d8e459a8e8..d49e4454af 100644 --- a/src/interfaces/chain.cpp +++ b/src/interfaces/chain.cpp @@ -13,6 +13,7 @@ #include <node/coin.h> #include <node/context.h> #include <node/transaction.h> +#include <node/ui_interface.h> #include <policy/fees.h> #include <policy/policy.h> #include <policy/rbf.h> @@ -25,7 +26,6 @@ #include <sync.h> #include <timedata.h> #include <txmempool.h> -#include <ui_interface.h> #include <uint256.h> #include <univalue.h> #include <util/system.h> @@ -63,9 +63,9 @@ public: { m_notifications->transactionAddedToMempool(tx); } - void TransactionRemovedFromMempool(const CTransactionRef& tx) override + void TransactionRemovedFromMempool(const CTransactionRef& tx, MemPoolRemovalReason reason) override { - m_notifications->transactionRemovedFromMempool(tx); + m_notifications->transactionRemovedFromMempool(tx, reason); } void BlockConnected(const std::shared_ptr<const CBlock>& block, const CBlockIndex* index) override { diff --git a/src/interfaces/chain.h b/src/interfaces/chain.h index 7dfc77db7b..bbeb0fa801 100644 --- a/src/interfaces/chain.h +++ b/src/interfaces/chain.h @@ -8,18 +8,21 @@ #include <optional.h> // For Optional and nullopt #include <primitives/transaction.h> // For CTransactionRef +#include <functional> #include <memory> #include <stddef.h> #include <stdint.h> #include <string> #include <vector> +class ArgsManager; class CBlock; class CFeeRate; class CRPCCommand; class CScheduler; class Coin; class uint256; +enum class MemPoolRemovalReason; enum class RBFTransactionState; struct bilingual_str; struct CBlockLocator; @@ -239,7 +242,7 @@ public: public: virtual ~Notifications() {} virtual void transactionAddedToMempool(const CTransactionRef& tx) {} - virtual void transactionRemovedFromMempool(const CTransactionRef& ptx) {} + virtual void transactionRemovedFromMempool(const CTransactionRef& tx, MemPoolRemovalReason reason) {} virtual void blockConnected(const CBlock& block, int height) {} virtual void blockDisconnected(const CBlock& block, int height) {} virtual void updatedBlockTip() {} @@ -320,7 +323,7 @@ std::unique_ptr<Chain> MakeChain(NodeContext& node); //! analysis, or fee estimation. These clients need to expose their own //! MakeXXXClient functions returning their implementations of the ChainClient //! interface. -std::unique_ptr<ChainClient> MakeWalletClient(Chain& chain, std::vector<std::string> wallet_filenames); +std::unique_ptr<ChainClient> MakeWalletClient(Chain& chain, ArgsManager& args, std::vector<std::string> wallet_filenames); } // namespace interfaces diff --git a/src/interfaces/node.cpp b/src/interfaces/node.cpp index bd1b36fea5..21400d00f8 100644 --- a/src/interfaces/node.cpp +++ b/src/interfaces/node.cpp @@ -17,6 +17,7 @@ #include <netaddress.h> #include <netbase.h> #include <node/context.h> +#include <node/ui_interface.h> #include <policy/feerate.h> #include <policy/fees.h> #include <policy/settings.h> @@ -26,7 +27,6 @@ #include <support/allocators/secure.h> #include <sync.h> #include <txmempool.h> -#include <ui_interface.h> #include <util/ref.h> #include <util/system.h> #include <util/translation.h> @@ -56,7 +56,8 @@ namespace { class NodeImpl : public Node { public: - void initError(const std::string& message) override { InitError(Untranslated(message)); } + NodeImpl(NodeContext* context) { setContext(context); } + void initError(const bilingual_str& message) override { InitError(message); } bool parseParameters(int argc, const char* const argv[], std::string& error) override { return gArgs.ParseParameters(argc, argv, error); @@ -66,27 +67,28 @@ public: bool softSetArg(const std::string& arg, const std::string& value) override { return gArgs.SoftSetArg(arg, value); } bool softSetBoolArg(const std::string& arg, bool value) override { return gArgs.SoftSetBoolArg(arg, value); } void selectParams(const std::string& network) override { SelectParams(network); } + bool initSettings(std::string& error) override { return gArgs.InitSettings(error); } uint64_t getAssumedBlockchainSize() override { return Params().AssumedBlockchainSize(); } uint64_t getAssumedChainStateSize() override { return Params().AssumedChainStateSize(); } std::string getNetwork() override { return Params().NetworkIDString(); } void initLogging() override { InitLogging(); } void initParameterInteraction() override { InitParameterInteraction(); } - std::string getWarnings() override { return GetWarnings(true); } + bilingual_str getWarnings() override { return GetWarnings(true); } uint32_t getLogCategories() override { return LogInstance().GetCategoryMask(); } bool baseInitialize() override { return AppInitBasicSetup() && AppInitParameterInteraction() && AppInitSanityChecks() && AppInitLockDataDirectory(); } - bool appInitMain() override + bool appInitMain(interfaces::BlockAndHeaderTipInfo* tip_info) override { - m_context.chain = MakeChain(m_context); - return AppInitMain(m_context_ref, m_context); + m_context->chain = MakeChain(*m_context); + return AppInitMain(m_context_ref, *m_context, tip_info); } void appShutdown() override { - Interrupt(m_context); - Shutdown(m_context); + Interrupt(*m_context); + Shutdown(*m_context); } void startShutdown() override { @@ -107,19 +109,19 @@ public: StopMapPort(); } } - void setupServerArgs() override { return SetupServerArgs(m_context); } + void setupServerArgs() override { return SetupServerArgs(*m_context); } bool getProxy(Network net, proxyType& proxy_info) override { return GetProxy(net, proxy_info); } size_t getNodeCount(CConnman::NumConnections flags) override { - return m_context.connman ? m_context.connman->GetNodeCount(flags) : 0; + return m_context->connman ? m_context->connman->GetNodeCount(flags) : 0; } bool getNodesStats(NodesStats& stats) override { stats.clear(); - if (m_context.connman) { + if (m_context->connman) { std::vector<CNodeStats> stats_temp; - m_context.connman->GetNodeStats(stats_temp); + m_context->connman->GetNodeStats(stats_temp); stats.reserve(stats_temp.size()); for (auto& node_stats_temp : stats_temp) { @@ -140,46 +142,46 @@ public: } bool getBanned(banmap_t& banmap) override { - if (m_context.banman) { - m_context.banman->GetBanned(banmap); + if (m_context->banman) { + m_context->banman->GetBanned(banmap); return true; } return false; } - bool ban(const CNetAddr& net_addr, BanReason reason, int64_t ban_time_offset) override + bool ban(const CNetAddr& net_addr, int64_t ban_time_offset) override { - if (m_context.banman) { - m_context.banman->Ban(net_addr, reason, ban_time_offset); + if (m_context->banman) { + m_context->banman->Ban(net_addr, ban_time_offset); return true; } return false; } bool unban(const CSubNet& ip) override { - if (m_context.banman) { - m_context.banman->Unban(ip); + if (m_context->banman) { + m_context->banman->Unban(ip); return true; } return false; } bool disconnectByAddress(const CNetAddr& net_addr) override { - if (m_context.connman) { - return m_context.connman->DisconnectNode(net_addr); + if (m_context->connman) { + return m_context->connman->DisconnectNode(net_addr); } return false; } bool disconnectById(NodeId id) override { - if (m_context.connman) { - return m_context.connman->DisconnectNode(id); + if (m_context->connman) { + return m_context->connman->DisconnectNode(id); } return false; } - int64_t getTotalBytesRecv() override { return m_context.connman ? m_context.connman->GetTotalBytesRecv() : 0; } - int64_t getTotalBytesSent() override { return m_context.connman ? m_context.connman->GetTotalBytesSent() : 0; } - size_t getMempoolSize() override { return m_context.mempool ? m_context.mempool->size() : 0; } - size_t getMempoolDynamicUsage() override { return m_context.mempool ? m_context.mempool->DynamicMemoryUsage() : 0; } + int64_t getTotalBytesRecv() override { return m_context->connman ? m_context->connman->GetTotalBytesRecv() : 0; } + int64_t getTotalBytesSent() override { return m_context->connman ? m_context->connman->GetTotalBytesSent() : 0; } + size_t getMempoolSize() override { return m_context->mempool ? m_context->mempool->size() : 0; } + size_t getMempoolDynamicUsage() override { return m_context->mempool ? m_context->mempool->DynamicMemoryUsage() : 0; } bool getHeaderTip(int& height, int64_t& block_time) override { LOCK(::cs_main); @@ -222,11 +224,11 @@ public: bool getImporting() override { return ::fImporting; } void setNetworkActive(bool active) override { - if (m_context.connman) { - m_context.connman->SetNetworkActive(active); + if (m_context->connman) { + m_context->connman->SetNetworkActive(active); } } - bool getNetworkActive() override { return m_context.connman && m_context.connman->GetNetworkActive(); } + bool getNetworkActive() override { return m_context->connman && m_context->connman->GetNetworkActive(); } CFeeRate estimateSmartFee(int num_blocks, bool conservative, int* returned_target = nullptr) override { FeeCalculation fee_calc; @@ -268,7 +270,7 @@ public: std::vector<std::unique_ptr<Wallet>> getWallets() override { std::vector<std::unique_ptr<Wallet>> wallets; - for (auto& client : m_context.chain_clients) { + for (auto& client : m_context->chain_clients) { auto client_wallets = client->getWallets(); std::move(client_wallets.begin(), client_wallets.end(), std::back_inserter(wallets)); } @@ -276,12 +278,12 @@ public: } std::unique_ptr<Wallet> loadWallet(const std::string& name, bilingual_str& error, std::vector<bilingual_str>& warnings) override { - return MakeWallet(LoadWallet(*m_context.chain, name, error, warnings)); + return MakeWallet(LoadWallet(*m_context->chain, name, error, warnings)); } std::unique_ptr<Wallet> createWallet(const SecureString& passphrase, uint64_t wallet_creation_flags, const std::string& name, bilingual_str& error, std::vector<bilingual_str>& warnings, WalletCreationStatus& status) override { std::shared_ptr<CWallet> wallet; - status = CreateWallet(*m_context.chain, passphrase, wallet_creation_flags, name, error, warnings, wallet); + status = CreateWallet(*m_context->chain, passphrase, wallet_creation_flags, name, error, warnings, wallet); return MakeWallet(wallet); } std::unique_ptr<Handler> handleInitMessage(InitMessageFn fn) override @@ -335,13 +337,22 @@ public: /* verification progress is unused when a header was received */ 0); })); } - NodeContext* context() override { return &m_context; } - NodeContext m_context; - util::Ref m_context_ref{m_context}; + NodeContext* context() override { return m_context; } + void setContext(NodeContext* context) override + { + m_context = context; + if (context) { + m_context_ref.Set(*context); + } else { + m_context_ref.Clear(); + } + } + NodeContext* m_context{nullptr}; + util::Ref m_context_ref; }; } // namespace -std::unique_ptr<Node> MakeNode() { return MakeUnique<NodeImpl>(); } +std::unique_ptr<Node> MakeNode(NodeContext* context) { return MakeUnique<NodeImpl>(context); } } // namespace interfaces diff --git a/src/interfaces/node.h b/src/interfaces/node.h index 0b7fb6736a..753f3e6b13 100644 --- a/src/interfaces/node.h +++ b/src/interfaces/node.h @@ -10,6 +10,7 @@ #include <net_types.h> // For banmap_t #include <netaddress.h> // For Network #include <support/allocators/secure.h> // For SecureString +#include <util/translation.h> #include <functional> #include <memory> @@ -38,6 +39,16 @@ class Handler; class Wallet; struct BlockTip; +//! Block and header tip information +struct BlockAndHeaderTipInfo +{ + int block_height; + int64_t block_time; + int header_height; + int64_t header_time; + double verification_progress; +}; + //! Top-level interface for a bitcoin node (bitcoind process). class Node { @@ -45,7 +56,7 @@ public: virtual ~Node() {} //! Send init error. - virtual void initError(const std::string& message) = 0; + virtual void initError(const bilingual_str& message) = 0; //! Set command line arguments. virtual bool parseParameters(int argc, const char* const argv[], std::string& error) = 0; @@ -65,6 +76,11 @@ public: //! Choose network parameters. virtual void selectParams(const std::string& network) = 0; + //! Read and update <datadir>/settings.json file with saved settings. This + //! needs to be called after selectParams() because the settings file + //! location is network-specific. + virtual bool initSettings(std::string& error) = 0; + //! Get the (assumed) blockchain size. virtual uint64_t getAssumedBlockchainSize() = 0; @@ -81,7 +97,7 @@ public: virtual void initParameterInteraction() = 0; //! Get warnings. - virtual std::string getWarnings() = 0; + virtual bilingual_str getWarnings() = 0; // Get log flags. virtual uint32_t getLogCategories() = 0; @@ -90,7 +106,7 @@ public: virtual bool baseInitialize() = 0; //! Start node. - virtual bool appInitMain() = 0; + virtual bool appInitMain(interfaces::BlockAndHeaderTipInfo* tip_info = nullptr) = 0; //! Stop node. virtual void appShutdown() = 0; @@ -121,7 +137,7 @@ public: virtual bool getBanned(banmap_t& banmap) = 0; //! Ban node. - virtual bool ban(const CNetAddr& net_addr, BanReason reason, int64_t ban_time_offset) = 0; + virtual bool ban(const CNetAddr& net_addr, int64_t ban_time_offset) = 0; //! Unban node. virtual bool unban(const CSubNet& ip) = 0; @@ -262,12 +278,14 @@ public: std::function<void(SynchronizationState, interfaces::BlockTip tip, double verification_progress)>; virtual std::unique_ptr<Handler> handleNotifyHeaderTip(NotifyHeaderTipFn fn) = 0; - //! Return pointer to internal chain interface, useful for testing. + //! Get and set internal node context. Useful for testing, but not + //! accessible across processes. virtual NodeContext* context() { return nullptr; } + virtual void setContext(NodeContext* context) { } }; //! Return implementation of Node interface. -std::unique_ptr<Node> MakeNode(); +std::unique_ptr<Node> MakeNode(NodeContext* context = nullptr); //! Block tip (could be a header or not, depends on the subscribed signal). struct BlockTip { diff --git a/src/interfaces/wallet.cpp b/src/interfaces/wallet.cpp index cec75030ad..7fd24425cf 100644 --- a/src/interfaces/wallet.cpp +++ b/src/interfaces/wallet.cpp @@ -9,13 +9,16 @@ #include <interfaces/handler.h> #include <policy/fees.h> #include <primitives/transaction.h> +#include <rpc/server.h> #include <script/standard.h> #include <support/allocators/secure.h> #include <sync.h> -#include <ui_interface.h> #include <uint256.h> #include <util/check.h> +#include <util/ref.h> #include <util/system.h> +#include <util/ui_change_type.h> +#include <wallet/context.h> #include <wallet/feebumper.h> #include <wallet/fees.h> #include <wallet/ismine.h> @@ -332,9 +335,10 @@ public: bool sign, bool bip32derivs, PartiallySignedTransaction& psbtx, - bool& complete) override + bool& complete, + size_t* n_signed) override { - return m_wallet->FillPSBT(psbtx, complete, sighash_type, sign, bip32derivs); + return m_wallet->FillPSBT(psbtx, complete, sighash_type, sign, bip32derivs, n_signed); } WalletBalances getBalances() override { @@ -434,7 +438,6 @@ public: bool canGetAddresses() override { return m_wallet->CanGetAddresses(); } bool privateKeysDisabled() override { return m_wallet->IsWalletFlagSet(WALLET_FLAG_DISABLE_PRIVATE_KEYS); } OutputType getDefaultAddressType() override { return m_wallet->m_default_address_type; } - OutputType getDefaultChangeType() override { return m_wallet->m_default_change_type; } CAmount getDefaultMaxTxFee() override { return m_wallet->m_default_max_tx_fee; } void remove() override { @@ -480,18 +483,24 @@ public: class WalletClientImpl : public ChainClient { public: - WalletClientImpl(Chain& chain, std::vector<std::string> wallet_filenames) - : m_chain(chain), m_wallet_filenames(std::move(wallet_filenames)) + WalletClientImpl(Chain& chain, ArgsManager& args, std::vector<std::string> wallet_filenames) + : m_wallet_filenames(std::move(wallet_filenames)) { + m_context.chain = &chain; + m_context.args = &args; } void registerRpcs() override { - g_rpc_chain = &m_chain; - return RegisterWalletRPCCommands(m_chain, m_rpc_handlers); + for (const CRPCCommand& command : GetWalletRPCCommands()) { + m_rpc_commands.emplace_back(command.category, command.name, [this, &command](const JSONRPCRequest& request, UniValue& result, bool last_handler) { + return command.actor({request, m_context}, result, last_handler); + }, command.argNames, command.unique_id); + m_rpc_handlers.emplace_back(m_context.chain->handleRpc(m_rpc_commands.back())); + } } - bool verify() override { return VerifyWallets(m_chain, m_wallet_filenames); } - bool load() override { return LoadWallets(m_chain, m_wallet_filenames); } - void start(CScheduler& scheduler) override { return StartWallets(scheduler); } + bool verify() override { return VerifyWallets(*m_context.chain, m_wallet_filenames); } + bool load() override { return LoadWallets(*m_context.chain, m_wallet_filenames); } + void start(CScheduler& scheduler) override { return StartWallets(scheduler, *Assert(m_context.args)); } void flush() override { return FlushWallets(); } void stop() override { return StopWallets(); } void setMockTime(int64_t time) override { return SetMockTime(time); } @@ -505,18 +514,19 @@ public: } ~WalletClientImpl() override { UnloadWallets(); } - Chain& m_chain; - std::vector<std::string> m_wallet_filenames; + WalletContext m_context; + const std::vector<std::string> m_wallet_filenames; std::vector<std::unique_ptr<Handler>> m_rpc_handlers; + std::list<CRPCCommand> m_rpc_commands; }; } // namespace std::unique_ptr<Wallet> MakeWallet(const std::shared_ptr<CWallet>& wallet) { return wallet ? MakeUnique<WalletImpl>(wallet) : nullptr; } -std::unique_ptr<ChainClient> MakeWalletClient(Chain& chain, std::vector<std::string> wallet_filenames) +std::unique_ptr<ChainClient> MakeWalletClient(Chain& chain, ArgsManager& args, std::vector<std::string> wallet_filenames) { - return MakeUnique<WalletClientImpl>(chain, std::move(wallet_filenames)); + return MakeUnique<WalletClientImpl>(chain, args, std::move(wallet_filenames)); } } // namespace interfaces diff --git a/src/interfaces/wallet.h b/src/interfaces/wallet.h index 67569a3e55..3cdadbc72e 100644 --- a/src/interfaces/wallet.h +++ b/src/interfaces/wallet.h @@ -9,8 +9,8 @@ #include <pubkey.h> // For CKeyID and CScriptID (definitions needed in CTxDestination instantiation) #include <script/standard.h> // For CTxDestination #include <support/allocators/secure.h> // For SecureString -#include <ui_interface.h> // For ChangeType #include <util/message.h> +#include <util/ui_change_type.h> #include <functional> #include <map> @@ -197,7 +197,8 @@ public: bool sign, bool bip32derivs, PartiallySignedTransaction& psbtx, - bool& complete) = 0; + bool& complete, + size_t* n_signed) = 0; //! Get balances. virtual WalletBalances getBalances() = 0; @@ -255,9 +256,6 @@ public: // Get default address type. virtual OutputType getDefaultAddressType() = 0; - // Get default change type. - virtual OutputType getDefaultChangeType() = 0; - //! Get max tx fee. virtual CAmount getDefaultMaxTxFee() = 0; diff --git a/src/key.cpp b/src/key.cpp index b6ed29e8e3..4ed74a39b1 100644 --- a/src/key.cpp +++ b/src/key.cpp @@ -31,46 +31,46 @@ static secp256k1_context* secp256k1_context_sign = nullptr; * * out32 must point to an output buffer of length at least 32 bytes. */ -static int ec_privkey_import_der(const secp256k1_context* ctx, unsigned char *out32, const unsigned char *privkey, size_t privkeylen) { - const unsigned char *end = privkey + privkeylen; +static int ec_seckey_import_der(const secp256k1_context* ctx, unsigned char *out32, const unsigned char *seckey, size_t seckeylen) { + const unsigned char *end = seckey + seckeylen; memset(out32, 0, 32); /* sequence header */ - if (end - privkey < 1 || *privkey != 0x30u) { + if (end - seckey < 1 || *seckey != 0x30u) { return 0; } - privkey++; + seckey++; /* sequence length constructor */ - if (end - privkey < 1 || !(*privkey & 0x80u)) { + if (end - seckey < 1 || !(*seckey & 0x80u)) { return 0; } - ptrdiff_t lenb = *privkey & ~0x80u; privkey++; + ptrdiff_t lenb = *seckey & ~0x80u; seckey++; if (lenb < 1 || lenb > 2) { return 0; } - if (end - privkey < lenb) { + if (end - seckey < lenb) { return 0; } /* sequence length */ - ptrdiff_t len = privkey[lenb-1] | (lenb > 1 ? privkey[lenb-2] << 8 : 0u); - privkey += lenb; - if (end - privkey < len) { + ptrdiff_t len = seckey[lenb-1] | (lenb > 1 ? seckey[lenb-2] << 8 : 0u); + seckey += lenb; + if (end - seckey < len) { return 0; } /* sequence element 0: version number (=1) */ - if (end - privkey < 3 || privkey[0] != 0x02u || privkey[1] != 0x01u || privkey[2] != 0x01u) { + if (end - seckey < 3 || seckey[0] != 0x02u || seckey[1] != 0x01u || seckey[2] != 0x01u) { return 0; } - privkey += 3; + seckey += 3; /* sequence element 1: octet string, up to 32 bytes */ - if (end - privkey < 2 || privkey[0] != 0x04u) { + if (end - seckey < 2 || seckey[0] != 0x04u) { return 0; } - ptrdiff_t oslen = privkey[1]; - privkey += 2; - if (oslen > 32 || end - privkey < oslen) { + ptrdiff_t oslen = seckey[1]; + seckey += 2; + if (oslen > 32 || end - seckey < oslen) { return 0; } - memcpy(out32 + (32 - oslen), privkey, oslen); + memcpy(out32 + (32 - oslen), seckey, oslen); if (!secp256k1_ec_seckey_verify(ctx, out32)) { memset(out32, 0, 32); return 0; @@ -83,17 +83,17 @@ static int ec_privkey_import_der(const secp256k1_context* ctx, unsigned char *ou * <http://www.secg.org/sec1-v2.pdf>. The optional parameters and publicKey fields are * included. * - * privkey must point to an output buffer of length at least CKey::SIZE bytes. - * privkeylen must initially be set to the size of the privkey buffer. Upon return it + * seckey must point to an output buffer of length at least CKey::SIZE bytes. + * seckeylen must initially be set to the size of the seckey buffer. Upon return it * will be set to the number of bytes used in the buffer. * key32 must point to a 32-byte raw private key. */ -static int ec_privkey_export_der(const secp256k1_context *ctx, unsigned char *privkey, size_t *privkeylen, const unsigned char *key32, bool compressed) { - assert(*privkeylen >= CKey::SIZE); +static int ec_seckey_export_der(const secp256k1_context *ctx, unsigned char *seckey, size_t *seckeylen, const unsigned char *key32, bool compressed) { + assert(*seckeylen >= CKey::SIZE); secp256k1_pubkey pubkey; size_t pubkeylen = 0; if (!secp256k1_ec_pubkey_create(ctx, &pubkey, key32)) { - *privkeylen = 0; + *seckeylen = 0; return 0; } if (compressed) { @@ -111,15 +111,15 @@ static int ec_privkey_export_der(const secp256k1_context *ctx, unsigned char *pr 0xFF,0xFF,0xFF,0xFF,0xFE,0xBA,0xAE,0xDC,0xE6,0xAF,0x48,0xA0,0x3B,0xBF,0xD2,0x5E, 0x8C,0xD0,0x36,0x41,0x41,0x02,0x01,0x01,0xA1,0x24,0x03,0x22,0x00 }; - unsigned char *ptr = privkey; + unsigned char *ptr = seckey; memcpy(ptr, begin, sizeof(begin)); ptr += sizeof(begin); memcpy(ptr, key32, 32); ptr += 32; memcpy(ptr, middle, sizeof(middle)); ptr += sizeof(middle); pubkeylen = CPubKey::COMPRESSED_SIZE; secp256k1_ec_pubkey_serialize(ctx, ptr, &pubkeylen, &pubkey, SECP256K1_EC_COMPRESSED); ptr += pubkeylen; - *privkeylen = ptr - privkey; - assert(*privkeylen == CKey::COMPRESSED_SIZE); + *seckeylen = ptr - seckey; + assert(*seckeylen == CKey::COMPRESSED_SIZE); } else { static const unsigned char begin[] = { 0x30,0x82,0x01,0x13,0x02,0x01,0x01,0x04,0x20 @@ -137,15 +137,15 @@ static int ec_privkey_export_der(const secp256k1_context *ctx, unsigned char *pr 0xFF,0xFF,0xFF,0xFF,0xFE,0xBA,0xAE,0xDC,0xE6,0xAF,0x48,0xA0,0x3B,0xBF,0xD2,0x5E, 0x8C,0xD0,0x36,0x41,0x41,0x02,0x01,0x01,0xA1,0x44,0x03,0x42,0x00 }; - unsigned char *ptr = privkey; + unsigned char *ptr = seckey; memcpy(ptr, begin, sizeof(begin)); ptr += sizeof(begin); memcpy(ptr, key32, 32); ptr += 32; memcpy(ptr, middle, sizeof(middle)); ptr += sizeof(middle); pubkeylen = CPubKey::SIZE; secp256k1_ec_pubkey_serialize(ctx, ptr, &pubkeylen, &pubkey, SECP256K1_EC_UNCOMPRESSED); ptr += pubkeylen; - *privkeylen = ptr - privkey; - assert(*privkeylen == CKey::SIZE); + *seckeylen = ptr - seckey; + assert(*seckeylen == CKey::SIZE); } return 1; } @@ -165,20 +165,20 @@ void CKey::MakeNewKey(bool fCompressedIn) { bool CKey::Negate() { assert(fValid); - return secp256k1_ec_privkey_negate(secp256k1_context_sign, keydata.data()); + return secp256k1_ec_seckey_negate(secp256k1_context_sign, keydata.data()); } CPrivKey CKey::GetPrivKey() const { assert(fValid); - CPrivKey privkey; + CPrivKey seckey; int ret; - size_t privkeylen; - privkey.resize(SIZE); - privkeylen = SIZE; - ret = ec_privkey_export_der(secp256k1_context_sign, privkey.data(), &privkeylen, begin(), fCompressed); + size_t seckeylen; + seckey.resize(SIZE); + seckeylen = SIZE; + ret = ec_seckey_export_der(secp256k1_context_sign, seckey.data(), &seckeylen, begin(), fCompressed); assert(ret); - privkey.resize(privkeylen); - return privkey; + seckey.resize(seckeylen); + return seckey; } CPubKey CKey::GetPubKey() const { @@ -237,7 +237,7 @@ bool CKey::VerifyPubKey(const CPubKey& pubkey) const { std::string str = "Bitcoin key verification\n"; GetRandBytes(rnd, sizeof(rnd)); uint256 hash; - CHash256().Write((unsigned char*)str.data(), str.size()).Write(rnd, sizeof(rnd)).Finalize(hash.begin()); + CHash256().Write(MakeUCharSpan(str)).Write(rnd).Finalize(hash); std::vector<unsigned char> vchSig; Sign(hash, vchSig); return pubkey.Verify(hash, vchSig); @@ -258,8 +258,8 @@ bool CKey::SignCompact(const uint256 &hash, std::vector<unsigned char>& vchSig) return true; } -bool CKey::Load(const CPrivKey &privkey, const CPubKey &vchPubKey, bool fSkipCheck=false) { - if (!ec_privkey_import_der(secp256k1_context_sign, (unsigned char*)begin(), privkey.data(), privkey.size())) +bool CKey::Load(const CPrivKey &seckey, const CPubKey &vchPubKey, bool fSkipCheck=false) { + if (!ec_seckey_import_der(secp256k1_context_sign, (unsigned char*)begin(), seckey.data(), seckey.size())) return false; fCompressed = vchPubKey.IsCompressed(); fValid = true; @@ -284,7 +284,7 @@ bool CKey::Derive(CKey& keyChild, ChainCode &ccChild, unsigned int nChild, const } memcpy(ccChild.begin(), vout.data()+32, 32); memcpy((unsigned char*)keyChild.begin(), begin(), 32); - bool ret = secp256k1_ec_privkey_tweak_add(secp256k1_context_sign, (unsigned char*)keyChild.begin(), vout.data()); + bool ret = secp256k1_ec_seckey_tweak_add(secp256k1_context_sign, (unsigned char*)keyChild.begin(), vout.data()); keyChild.fCompressed = true; keyChild.fValid = ret; return ret; diff --git a/src/logging.cpp b/src/logging.cpp index fe58ae9e73..35e0754f20 100644 --- a/src/logging.cpp +++ b/src/logging.cpp @@ -95,15 +95,7 @@ void BCLog::Logger::EnableCategory(BCLog::LogFlags flag) bool BCLog::Logger::EnableCategory(const std::string& str) { BCLog::LogFlags flag; - if (!GetLogCategory(flag, str)) { - if (str == "db") { - // DEPRECATION: Added in 0.20, should start returning an error in 0.21 - LogPrintf("Warning: logging category 'db' is deprecated, use 'walletdb' instead\n"); - EnableCategory(BCLog::WALLETDB); - return true; - } - return false; - } + if (!GetLogCategory(flag, str)) return false; EnableCategory(flag); return true; } diff --git a/src/merkleblock.cpp b/src/merkleblock.cpp index 8072b12119..b571d463c9 100644 --- a/src/merkleblock.cpp +++ b/src/merkleblock.cpp @@ -70,7 +70,7 @@ uint256 CPartialMerkleTree::CalcHash(int height, unsigned int pos, const std::ve else right = left; // combine subhashes - return Hash(left.begin(), left.end(), right.begin(), right.end()); + return Hash(left, right); } } @@ -126,7 +126,7 @@ uint256 CPartialMerkleTree::TraverseAndExtract(int height, unsigned int pos, uns right = left; } // and combine them before returning - return Hash(left.begin(), left.end(), right.begin(), right.end()); + return Hash(left, right); } } diff --git a/src/miner.cpp b/src/miner.cpp index d9dcbe8a70..41a835f70a 100644 --- a/src/miner.cpp +++ b/src/miner.cpp @@ -109,7 +109,7 @@ std::unique_ptr<CBlockTemplate> BlockAssembler::CreateNewBlock(const CScript& sc if(!pblocktemplate.get()) return nullptr; - pblock = &pblocktemplate->block; // pointer for convenience + CBlock* const pblock = &pblocktemplate->block; // pointer for convenience // Add dummy coinbase tx as first transaction pblock->vtx.emplace_back(); @@ -226,7 +226,7 @@ bool BlockAssembler::TestPackageTransactions(const CTxMemPool::setEntries& packa void BlockAssembler::AddToBlock(CTxMemPool::txiter iter) { - pblock->vtx.emplace_back(iter->GetSharedTx()); + pblocktemplate->block.vtx.emplace_back(iter->GetSharedTx()); pblocktemplate->vTxFees.push_back(iter->GetFee()); pblocktemplate->vTxSigOpsCost.push_back(iter->GetSigOpCost()); nBlockWeight += iter->GetTxWeight(); diff --git a/src/miner.h b/src/miner.h index 69296f9078..096585dfe4 100644 --- a/src/miner.h +++ b/src/miner.h @@ -128,8 +128,6 @@ class BlockAssembler private: // The constructed block template std::unique_ptr<CBlockTemplate> pblocktemplate; - // A convenience pointer that always refers to the CBlock in pblocktemplate - CBlock* pblock; // Configuration parameters for the block size bool fIncludeWitness; diff --git a/src/net.cpp b/src/net.cpp index 707412bb32..6c1980735c 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -14,12 +14,12 @@ #include <clientversion.h> #include <consensus/consensus.h> #include <crypto/sha256.h> -#include <netbase.h> #include <net_permissions.h> +#include <netbase.h> +#include <node/ui_interface.h> #include <protocol.h> #include <random.h> #include <scheduler.h> -#include <ui_interface.h> #include <util/strencodings.h> #include <util/translation.h> @@ -42,6 +42,7 @@ static_assert(MINIUPNPC_API_VERSION >= 10, "miniUPnPc API version >= 10 assumed"); #endif +#include <cstdint> #include <unordered_map> #include <math.h> @@ -53,10 +54,17 @@ static constexpr std::chrono::minutes DUMP_PEERS_INTERVAL{15}; static constexpr int DNSSEEDS_TO_QUERY_AT_ONCE = 3; /** How long to delay before querying DNS seeds + * + * If we have more than THRESHOLD entries in addrman, then it's likely + * that we got those addresses from having previously connected to the P2P + * network, and that we'll be able to successfully reconnect to the P2P + * network via contacting one of them. So if that's the case, spend a + * little longer trying to connect to known peers before querying the + * DNS seeds. */ -static constexpr std::chrono::seconds DNSSEEDS_DELAY_FEW_PEERS{11}; // 11sec -static constexpr std::chrono::seconds DNSSEEDS_DELAY_MANY_PEERS{300}; // 5min -static constexpr int DNSSEEDS_DELAY_PEER_THRESHOLD = 1000; // "many" vs "few" peers -- you should only get this many if you've been on the live network +static constexpr std::chrono::seconds DNSSEEDS_DELAY_FEW_PEERS{11}; +static constexpr std::chrono::minutes DNSSEEDS_DELAY_MANY_PEERS{5}; +static constexpr int DNSSEEDS_DELAY_PEER_THRESHOLD = 1000; // "many" vs "few" peers // We add a random period time (0 to 1 seconds) to feeler connections to prevent synchronization. #define FEELER_SLEEP_WINDOW 1 @@ -97,15 +105,15 @@ std::map<CNetAddr, LocalServiceInfo> mapLocalHost GUARDED_BY(cs_mapLocalHost); static bool vfLimited[NET_MAX] GUARDED_BY(cs_mapLocalHost) = {}; std::string strSubVersion; -void CConnman::AddOneShot(const std::string& strDest) +void CConnman::AddAddrFetch(const std::string& strDest) { - LOCK(cs_vOneShots); - vOneShots.push_back(strDest); + LOCK(m_addr_fetches_mutex); + m_addr_fetches.push_back(strDest); } -unsigned short GetListenPort() +uint16_t GetListenPort() { - return (unsigned short)(gArgs.GetArg("-port", Params().GetDefaultPort())); + return (uint16_t)(gArgs.GetArg("-port", Params().GetDefaultPort())); } // find 'best' local address for a particular peer @@ -338,7 +346,7 @@ bool CConnman::CheckIncomingNonce(uint64_t nonce) { LOCK(cs_vNodes); for (const CNode* pnode : vNodes) { - if (!pnode->fSuccessfullyConnected && !pnode->fInbound && pnode->GetLocalNonce() == nonce) + if (!pnode->fSuccessfullyConnected && !pnode->IsInboundConn() && pnode->GetLocalNonce() == nonce) return false; } return true; @@ -360,8 +368,10 @@ static CAddress GetBindAddress(SOCKET sock) return addr_bind; } -CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCountFailure, bool manual_connection, bool block_relay_only) +CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCountFailure, ConnectionType conn_type) { + assert(conn_type != ConnectionType::INBOUND); + if (pszDest == nullptr) { if (IsLocal(addrConnect)) return nullptr; @@ -424,7 +434,7 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo if (hSocket == INVALID_SOCKET) { return nullptr; } - connected = ConnectSocketDirectly(addrConnect, hSocket, nConnectTimeout, manual_connection); + connected = ConnectSocketDirectly(addrConnect, hSocket, nConnectTimeout, conn_type == ConnectionType::MANUAL); } if (!proxyConnectionFailed) { // If a connection to the node was attempted, and failure (if any) is not caused by a problem connecting to @@ -451,7 +461,7 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo NodeId id = GetNewNodeId(); uint64_t nonce = GetDeterministicRandomizer(RANDOMIZER_ID_LOCALHOSTNONCE).Write(id).Finalize(); CAddress addr_bind = GetBindAddress(hSocket); - CNode* pnode = new CNode(id, nLocalServices, GetBestHeight(), hSocket, addrConnect, CalculateKeyedNetGroup(addrConnect), nonce, addr_bind, pszDest ? pszDest : "", false, block_relay_only); + CNode* pnode = new CNode(id, nLocalServices, GetBestHeight(), hSocket, addrConnect, CalculateKeyedNetGroup(addrConnect), nonce, addr_bind, pszDest ? pszDest : "", conn_type); pnode->AddRef(); // We're making a new connection, harvest entropy from the time (and our peer count) @@ -528,8 +538,8 @@ void CNode::copyStats(CNodeStats &stats, const std::vector<bool> &m_asmap) LOCK(cs_SubVer); X(cleanSubVer); } - X(fInbound); - X(m_manual_connection); + stats.fInbound = IsInboundConn(); + stats.m_manual_connection = IsManualConn(); X(nStartingHeight); { LOCK(cs_vSend); @@ -556,15 +566,15 @@ void CNode::copyStats(CNodeStats &stats, const std::vector<bool> &m_asmap) // since pingtime does not update until the ping is complete, which might take a while. // So, if a ping is taking an unusually long time in flight, // the caller can immediately detect that this is happening. - int64_t nPingUsecWait = 0; - if ((0 != nPingNonceSent) && (0 != nPingUsecStart)) { - nPingUsecWait = GetTimeMicros() - nPingUsecStart; + std::chrono::microseconds ping_wait{0}; + if ((0 != nPingNonceSent) && (0 != m_ping_start.load().count())) { + ping_wait = GetTime<std::chrono::microseconds>() - m_ping_start.load(); } // Raw ping time is in microseconds, but show it to user as whole seconds (Bitcoin users should be well used to small numbers with many decimal places by now :) stats.m_ping_usec = nPingUsecTime; stats.m_min_ping_usec = nMinPingUsecTime; - stats.m_ping_wait_usec = nPingUsecWait; + stats.m_ping_wait_usec = count_microseconds(ping_wait); // Leave string empty if addrLocal invalid (not filled in yet) CService addrLocalUnlocked = GetAddrLocal(); @@ -575,9 +585,9 @@ void CNode::copyStats(CNodeStats &stats, const std::vector<bool> &m_asmap) bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete) { complete = false; - int64_t nTimeMicros = GetTimeMicros(); + const auto time = GetTime<std::chrono::microseconds>(); LOCK(cs_vRecv); - nLastRecv = nTimeMicros / 1000000; + nLastRecv = std::chrono::duration_cast<std::chrono::seconds>(time).count(); nRecvBytes += nBytes; while (nBytes > 0) { // absorb network data @@ -589,7 +599,7 @@ bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete if (m_deserializer->Complete()) { // decompose a transport agnostic CNetMessage from the deserializer - CNetMessage msg = m_deserializer->GetMessage(Params().MessageStart(), nTimeMicros); + CNetMessage msg = m_deserializer->GetMessage(Params().MessageStart(), time); //store received bytes per message command //to prevent a memory DOS, only allow valid commands @@ -677,7 +687,7 @@ int V1TransportDeserializer::readData(const char *pch, unsigned int nBytes) vRecv.resize(std::min(hdr.nMessageSize, nDataPos + nCopy + 256 * 1024)); } - hasher.Write((const unsigned char*)pch, nCopy); + hasher.Write({(const unsigned char*)pch, nCopy}); memcpy(&vRecv[nDataPos], pch, nCopy); nDataPos += nCopy; @@ -688,11 +698,12 @@ const uint256& V1TransportDeserializer::GetMessageHash() const { assert(Complete()); if (data_hash.IsNull()) - hasher.Finalize(data_hash.begin()); + hasher.Finalize(data_hash); return data_hash; } -CNetMessage V1TransportDeserializer::GetMessage(const CMessageHeader::MessageStartChars& message_start, int64_t time) { +CNetMessage V1TransportDeserializer::GetMessage(const CMessageHeader::MessageStartChars& message_start, const std::chrono::microseconds time) +{ // decompose a single CNetMessage from the TransportDeserializer CNetMessage msg(std::move(vRecv)); @@ -713,8 +724,8 @@ CNetMessage V1TransportDeserializer::GetMessage(const CMessageHeader::MessageSta if (!msg.m_valid_checksum) { LogPrint(BCLog::NET, "CHECKSUM ERROR (%s, %u bytes), expected %s was %s\n", SanitizeString(msg.m_command), msg.m_message_size, - HexStr(hash.begin(), hash.begin()+CMessageHeader::CHECKSUM_SIZE), - HexStr(hdr.pchChecksum, hdr.pchChecksum+CMessageHeader::CHECKSUM_SIZE)); + HexStr(Span<uint8_t>(hash.begin(), hash.begin() + CMessageHeader::CHECKSUM_SIZE)), + HexStr(hdr.pchChecksum)); } // store receive time @@ -727,10 +738,10 @@ CNetMessage V1TransportDeserializer::GetMessage(const CMessageHeader::MessageSta void V1TransportSerializer::prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) { // create dbl-sha256 checksum - uint256 hash = Hash(msg.data.begin(), msg.data.end()); + uint256 hash = Hash(msg.data); // create header - CMessageHeader hdr(Params().MessageStart(), msg.command.c_str(), msg.data.size()); + CMessageHeader hdr(Params().MessageStart(), msg.m_type.c_str(), msg.data.size()); memcpy(hdr.pchChecksum, hash.begin(), CMessageHeader::CHECKSUM_SIZE); // serialize header @@ -863,7 +874,7 @@ bool CConnman::AttemptToEvictConnection() for (const CNode* node : vNodes) { if (node->HasPermission(PF_NOBAN)) continue; - if (!node->fInbound) + if (!node->IsInboundConn()) continue; if (node->fDisconnect) continue; @@ -974,7 +985,7 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket) { { LOCK(cs_vNodes); for (const CNode* pnode : vNodes) { - if (pnode->fInbound) nInbound++; + if (pnode->IsInboundConn()) nInbound++; } } @@ -1003,17 +1014,24 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket) { // on all platforms. Set it again here just to be sure. SetSocketNoDelay(hSocket); - int bannedlevel = m_banman ? m_banman->IsBannedLevel(addr) : 0; - - // Don't accept connections from banned peers, but if our inbound slots aren't almost full, accept - // if the only banning reason was an automatic misbehavior ban. - if (!NetPermissions::HasFlag(permissionFlags, NetPermissionFlags::PF_NOBAN) && bannedlevel > ((nInbound + 1 < nMaxInbound) ? 1 : 0)) + // Don't accept connections from banned peers. + bool banned = m_banman && m_banman->IsBanned(addr); + if (!NetPermissions::HasFlag(permissionFlags, NetPermissionFlags::PF_NOBAN) && banned) { LogPrint(BCLog::NET, "connection from %s dropped (banned)\n", addr.ToString()); CloseSocket(hSocket); return; } + // Only accept connections from discouraged peers if our inbound slots aren't (almost) full. + bool discouraged = m_banman && m_banman->IsDiscouraged(addr); + if (!NetPermissions::HasFlag(permissionFlags, NetPermissionFlags::PF_NOBAN) && nInbound + 1 >= nMaxInbound && discouraged) + { + LogPrint(BCLog::NET, "connection from %s dropped (discouraged)\n", addr.ToString()); + CloseSocket(hSocket); + return; + } + if (nInbound >= nMaxInbound) { if (!AttemptToEvictConnection()) { @@ -1032,12 +1050,12 @@ void CConnman::AcceptConnection(const ListenSocket& hListenSocket) { if (NetPermissions::HasFlag(permissionFlags, PF_BLOOMFILTER)) { nodeServices = static_cast<ServiceFlags>(nodeServices | NODE_BLOOM); } - CNode* pnode = new CNode(id, nodeServices, GetBestHeight(), hSocket, addr, CalculateKeyedNetGroup(addr), nonce, addr_bind, "", true); + CNode* pnode = new CNode(id, nodeServices, GetBestHeight(), hSocket, addr, CalculateKeyedNetGroup(addr), nonce, addr_bind, "", ConnectionType::INBOUND); pnode->AddRef(); pnode->m_permissionFlags = permissionFlags; // If this flag is present, the user probably expect that RPC and QT report it as whitelisted (backward compatibility) pnode->m_legacyWhitelisted = legacyWhitelisted; - pnode->m_prefer_evict = bannedlevel > 0; + pnode->m_prefer_evict = discouraged; m_msgproc->InitializeNode(pnode); LogPrint(BCLog::NET, "connection from %s accepted\n", addr.ToString()); @@ -1096,12 +1114,9 @@ void CConnman::DisconnectNodes() if (pnode->GetRefCount() <= 0) { bool fDelete = false; { - TRY_LOCK(pnode->cs_inventory, lockInv); - if (lockInv) { - TRY_LOCK(pnode->cs_vSend, lockSend); - if (lockSend) { - fDelete = true; - } + TRY_LOCK(pnode->cs_vSend, lockSend); + if (lockSend) { + fDelete = true; } } if (fDelete) { @@ -1147,9 +1162,9 @@ void CConnman::InactivityCheck(CNode *pnode) LogPrintf("socket receive timeout: %is\n", nTime - pnode->nLastRecv); pnode->fDisconnect = true; } - else if (pnode->nPingNonceSent && pnode->nPingUsecStart + TIMEOUT_INTERVAL * 1000000 < GetTimeMicros()) + else if (pnode->nPingNonceSent && pnode->m_ping_start.load() + std::chrono::seconds{TIMEOUT_INTERVAL} < GetTime<std::chrono::microseconds>()) { - LogPrintf("ping timeout: %fs\n", 0.000001 * (GetTimeMicros() - pnode->nPingUsecStart)); + LogPrintf("ping timeout: %fs\n", 0.000001 * count_microseconds(GetTime<std::chrono::microseconds>() - pnode->m_ping_start.load())); pnode->fDisconnect = true; } else if (!pnode->fSuccessfullyConnected) @@ -1595,6 +1610,8 @@ void CConnman::ThreadDNSAddressSeed() seeds_right_now = seeds.size(); } else if (addrman.size() == 0) { // If we have no known peers, query all. + // This will occur on the first run, or if peers.dat has been + // deleted. seeds_right_now = seeds.size(); } @@ -1620,6 +1637,9 @@ void CConnman::ThreadDNSAddressSeed() LogPrintf("Waiting %d seconds before querying DNS seeds.\n", seeds_wait_time.count()); std::chrono::seconds to_wait = seeds_wait_time; while (to_wait.count() > 0) { + // if sleeping for the MANY_PEERS interval, wake up + // early to see if we have enough peers and can stop + // this thread entirely freeing up its resources std::chrono::seconds w = std::min(DNSSEEDS_DELAY_FEW_PEERS, to_wait); if (!interruptNet.sleep_for(w)) return; to_wait -= w; @@ -1628,7 +1648,7 @@ void CConnman::ThreadDNSAddressSeed() { LOCK(cs_vNodes); for (const CNode* pnode : vNodes) { - nRelevant += pnode->fSuccessfullyConnected && !pnode->fFeeler && !pnode->fOneShot && !pnode->m_manual_connection && !pnode->fInbound; + if (pnode->fSuccessfullyConnected && pnode->IsOutboundOrBlockRelayConn()) ++nRelevant; } } if (nRelevant >= 2) { @@ -1646,7 +1666,7 @@ void CConnman::ThreadDNSAddressSeed() if (interruptNet) return; - // hold off on querying seeds if p2p network deactivated + // hold off on querying seeds if P2P network deactivated if (!fNetworkActive) { LogPrintf("Waiting for network to be reactivated before querying DNS seeds.\n"); do { @@ -1656,7 +1676,7 @@ void CConnman::ThreadDNSAddressSeed() LogPrintf("Loading addresses from DNS seed %s\n", seed); if (HaveNameProxy()) { - AddOneShot(seed); + AddAddrFetch(seed); } else { std::vector<CNetAddr> vIPs; std::vector<CAddress> vAdd; @@ -1678,8 +1698,8 @@ void CConnman::ThreadDNSAddressSeed() addrman.Add(vAdd, resolveSource); } else { // We now avoid directly using results from DNS Seeds which do not support service bit filtering, - // instead using them as a oneshot to get nodes with our desired service bits. - AddOneShot(seed); + // instead using them as a addrfetch to get nodes with our desired service bits. + AddAddrFetch(seed); } } --seeds_right_now; @@ -1687,17 +1707,6 @@ void CConnman::ThreadDNSAddressSeed() LogPrintf("%d addresses found from DNS seeds\n", found); } - - - - - - - - - - - void CConnman::DumpAddresses() { int64_t nStart = GetTimeMillis(); @@ -1709,20 +1718,20 @@ void CConnman::DumpAddresses() addrman.size(), GetTimeMillis() - nStart); } -void CConnman::ProcessOneShot() +void CConnman::ProcessAddrFetch() { std::string strDest; { - LOCK(cs_vOneShots); - if (vOneShots.empty()) + LOCK(m_addr_fetches_mutex); + if (m_addr_fetches.empty()) return; - strDest = vOneShots.front(); - vOneShots.pop_front(); + strDest = m_addr_fetches.front(); + m_addr_fetches.pop_front(); } CAddress addr; CSemaphoreGrant grant(*semOutbound, true); if (grant) { - OpenNetworkConnection(addr, false, &grant, strDest.c_str(), true); + OpenNetworkConnection(addr, false, &grant, strDest.c_str(), ConnectionType::ADDR_FETCH); } } @@ -1749,7 +1758,7 @@ int CConnman::GetExtraOutboundCount() { LOCK(cs_vNodes); for (const CNode* pnode : vNodes) { - if (!pnode->fInbound && !pnode->m_manual_connection && !pnode->fFeeler && !pnode->fDisconnect && !pnode->fOneShot && pnode->fSuccessfullyConnected) { + if (pnode->fSuccessfullyConnected && !pnode->fDisconnect && pnode->IsOutboundOrBlockRelayConn()) { ++nOutbound; } } @@ -1764,11 +1773,11 @@ void CConnman::ThreadOpenConnections(const std::vector<std::string> connect) { for (int64_t nLoop = 0;; nLoop++) { - ProcessOneShot(); + ProcessAddrFetch(); for (const std::string& strAddr : connect) { CAddress addr(CService(), NODE_NONE); - OpenNetworkConnection(addr, false, nullptr, strAddr.c_str(), false, false, true); + OpenNetworkConnection(addr, false, nullptr, strAddr.c_str(), ConnectionType::MANUAL); for (int i = 0; i < 10 && i < nLoop; i++) { if (!interruptNet.sleep_for(std::chrono::milliseconds(500))) @@ -1787,7 +1796,7 @@ void CConnman::ThreadOpenConnections(const std::vector<std::string> connect) int64_t nNextFeeler = PoissonNextSend(nStart*1000*1000, FEELER_INTERVAL); while (!interruptNet) { - ProcessOneShot(); + ProcessAddrFetch(); if (!interruptNet.sleep_for(std::chrono::milliseconds(500))) return; @@ -1797,6 +1806,9 @@ void CConnman::ThreadOpenConnections(const std::vector<std::string> connect) return; // Add seed nodes if DNS seeds are all down (an infrastructure attack?). + // Note that we only do this if we started with an empty peers.dat, + // (in which case we will query DNS seeds immediately) *and* the DNS + // seeds have not returned any results. if (addrman.size() == 0 && (GetTime() - nStart > 60)) { static bool done = false; if (!done) { @@ -1817,21 +1829,27 @@ void CConnman::ThreadOpenConnections(const std::vector<std::string> connect) int nOutboundFullRelay = 0; int nOutboundBlockRelay = 0; std::set<std::vector<unsigned char> > setConnected; + { LOCK(cs_vNodes); for (const CNode* pnode : vNodes) { - if (!pnode->fInbound && !pnode->m_manual_connection) { - // Netgroups for inbound and addnode peers are not excluded because our goal here - // is to not use multiple of our limited outbound slots on a single netgroup - // but inbound and addnode peers do not use our outbound slots. Inbound peers - // also have the added issue that they're attacker controlled and could be used - // to prevent us from connecting to particular hosts if we used them here. - setConnected.insert(pnode->addr.GetGroup(addrman.m_asmap)); - if (pnode->m_tx_relay == nullptr) { - nOutboundBlockRelay++; - } else if (!pnode->fFeeler) { - nOutboundFullRelay++; - } + if (pnode->IsFullOutboundConn()) nOutboundFullRelay++; + if (pnode->IsBlockOnlyConn()) nOutboundBlockRelay++; + + // Netgroups for inbound and manual peers are not excluded because our goal here + // is to not use multiple of our limited outbound slots on a single netgroup + // but inbound and manual peers do not use our outbound slots. Inbound peers + // also have the added issue that they could be attacker controlled and used + // to prevent us from connecting to particular hosts if we used them here. + switch(pnode->m_conn_type){ + case ConnectionType::INBOUND: + case ConnectionType::MANUAL: + break; + case ConnectionType::OUTBOUND: + case ConnectionType::BLOCK_RELAY: + case ConnectionType::ADDR_FETCH: + case ConnectionType::FEELER: + setConnected.insert(pnode->addr.GetGroup(addrman.m_asmap)); } } } @@ -1924,14 +1942,24 @@ void CConnman::ThreadOpenConnections(const std::vector<std::string> connect) LogPrint(BCLog::NET, "Making feeler connection to %s\n", addrConnect.ToString()); } - // Open this connection as block-relay-only if we're already at our - // full-relay capacity, but not yet at our block-relay peer limit. - // (It should not be possible for fFeeler to be set if we're not - // also at our block-relay peer limit, but check against that as - // well for sanity.) - bool block_relay_only = nOutboundBlockRelay < m_max_outbound_block_relay && !fFeeler && nOutboundFullRelay >= m_max_outbound_full_relay; + ConnectionType conn_type; + // Determine what type of connection to open. If fFeeler is not + // set, open OUTBOUND connections until we meet our full-relay + // capacity. Then open BLOCK_RELAY connections until we hit our + // block-relay peer limit. Otherwise, default to opening an + // OUTBOUND connection. + if (fFeeler) { + conn_type = ConnectionType::FEELER; + } else if (nOutboundFullRelay < m_max_outbound_full_relay) { + conn_type = ConnectionType::OUTBOUND; + } else if (nOutboundBlockRelay < m_max_outbound_block_relay) { + conn_type = ConnectionType::BLOCK_RELAY; + } else { + // GetTryNewOutboundPeer() is true + conn_type = ConnectionType::OUTBOUND; + } - OpenNetworkConnection(addrConnect, (int)setConnected.size() >= std::min(nMaxConnections - 1, 2), &grant, nullptr, false, fFeeler, false, block_relay_only); + OpenNetworkConnection(addrConnect, (int)setConnected.size() >= std::min(nMaxConnections - 1, 2), &grant, nullptr, conn_type); } } } @@ -1955,11 +1983,11 @@ std::vector<AddedNodeInfo> CConnman::GetAddedNodeInfo() LOCK(cs_vNodes); for (const CNode* pnode : vNodes) { if (pnode->addr.IsValid()) { - mapConnected[pnode->addr] = pnode->fInbound; + mapConnected[pnode->addr] = pnode->IsInboundConn(); } std::string addrName = pnode->GetAddrName(); if (!addrName.empty()) { - mapConnectedByName[std::move(addrName)] = std::make_pair(pnode->fInbound, static_cast<const CService&>(pnode->addr)); + mapConnectedByName[std::move(addrName)] = std::make_pair(pnode->IsInboundConn(), static_cast<const CService&>(pnode->addr)); } } } @@ -2006,7 +2034,7 @@ void CConnman::ThreadOpenAddedConnections() } tried = true; CAddress addr(CService(), NODE_NONE); - OpenNetworkConnection(addr, false, &grant, info.strAddedNode.c_str(), false, false, true); + OpenNetworkConnection(addr, false, &grant, info.strAddedNode.c_str(), ConnectionType::MANUAL); if (!interruptNet.sleep_for(std::chrono::milliseconds(500))) return; } @@ -2018,8 +2046,10 @@ void CConnman::ThreadOpenAddedConnections() } // if successful, this moves the passed grant to the constructed node -void CConnman::OpenNetworkConnection(const CAddress& addrConnect, bool fCountFailure, CSemaphoreGrant *grantOutbound, const char *pszDest, bool fOneShot, bool fFeeler, bool manual_connection, bool block_relay_only) +void CConnman::OpenNetworkConnection(const CAddress& addrConnect, bool fCountFailure, CSemaphoreGrant *grantOutbound, const char *pszDest, ConnectionType conn_type) { + assert(conn_type != ConnectionType::INBOUND); + // // Initiate outbound network connection // @@ -2030,25 +2060,19 @@ void CConnman::OpenNetworkConnection(const CAddress& addrConnect, bool fCountFai return; } if (!pszDest) { - if (IsLocal(addrConnect) || - FindNode(static_cast<CNetAddr>(addrConnect)) || (m_banman && m_banman->IsBanned(addrConnect)) || - FindNode(addrConnect.ToStringIPPort())) + bool banned_or_discouraged = m_banman && (m_banman->IsDiscouraged(addrConnect) || m_banman->IsBanned(addrConnect)); + if (IsLocal(addrConnect) || FindNode(static_cast<CNetAddr>(addrConnect)) || banned_or_discouraged || FindNode(addrConnect.ToStringIPPort())) { return; + } } else if (FindNode(std::string(pszDest))) return; - CNode* pnode = ConnectNode(addrConnect, pszDest, fCountFailure, manual_connection, block_relay_only); + CNode* pnode = ConnectNode(addrConnect, pszDest, fCountFailure, conn_type); if (!pnode) return; if (grantOutbound) grantOutbound->MoveTo(pnode->grantOutbound); - if (fOneShot) - pnode->fOneShot = true; - if (fFeeler) - pnode->fFeeler = true; - if (manual_connection) - pnode->m_manual_connection = true; m_msgproc->InitializeNode(pnode); { @@ -2106,11 +2130,6 @@ void CConnman::ThreadMessageHandler() } } - - - - - bool CConnman::BindListenPort(const CService& addrBind, bilingual_str& strError, NetPermissionFlags permissions) { int nOne = 1; @@ -2232,7 +2251,7 @@ void Discover() void CConnman::SetNetworkActive(bool active) { - LogPrint(BCLog::NET, "SetNetworkActive: %s\n", active); + LogPrintf("%s: %s\n", __func__, active); if (fNetworkActive == active) { return; @@ -2243,12 +2262,14 @@ void CConnman::SetNetworkActive(bool active) uiInterface.NotifyNetworkActiveChanged(fNetworkActive); } -CConnman::CConnman(uint64_t nSeed0In, uint64_t nSeed1In) : nSeed0(nSeed0In), nSeed1(nSeed1In) +CConnman::CConnman(uint64_t nSeed0In, uint64_t nSeed1In, bool network_active) + : nSeed0(nSeed0In), nSeed1(nSeed1In) { SetTryNewOutboundPeer(false); Options connOptions; Init(connOptions); + SetNetworkActive(network_active); } NodeId CConnman::GetNewNodeId() @@ -2314,7 +2335,7 @@ bool CConnman::Start(CScheduler& scheduler, const Options& connOptions) } for (const auto& strDest : connOptions.vSeedNodes) { - AddOneShot(strDest); + AddAddrFetch(strDest); } if (clientInterface) { @@ -2367,7 +2388,7 @@ bool CConnman::Start(CScheduler& scheduler, const Options& connOptions) else threadDNSAddressSeed = std::thread(&TraceThread<std::function<void()> >, "dnsseed", std::function<void()>(std::bind(&CConnman::ThreadDNSAddressSeed, this))); - // Initiate outbound connections from -addnode + // Initiate manual connections threadOpenAddedConnections = std::thread(&TraceThread<std::function<void()> >, "addcon", std::function<void()>(std::bind(&CConnman::ThreadOpenAddedConnections, this))); if (connOptions.m_use_addrman_outgoing && !connOptions.m_specified_outgoing.empty()) { @@ -2490,11 +2511,6 @@ CConnman::~CConnman() Stop(); } -size_t CConnman::GetAddressCount() const -{ - return addrman.size(); -} - void CConnman::SetServices(const CService &addr, ServiceFlags nServices) { addrman.SetServices(addr, nServices); @@ -2505,14 +2521,31 @@ void CConnman::MarkAddressGood(const CAddress& addr) addrman.Good(addr); } -void CConnman::AddNewAddresses(const std::vector<CAddress>& vAddr, const CAddress& addrFrom, int64_t nTimePenalty) +bool CConnman::AddNewAddresses(const std::vector<CAddress>& vAddr, const CAddress& addrFrom, int64_t nTimePenalty) { - addrman.Add(vAddr, addrFrom, nTimePenalty); + return addrman.Add(vAddr, addrFrom, nTimePenalty); } -std::vector<CAddress> CConnman::GetAddresses() +std::vector<CAddress> CConnman::GetAddresses(size_t max_addresses, size_t max_pct) { - return addrman.GetAddr(); + std::vector<CAddress> addresses = addrman.GetAddr(max_addresses, max_pct); + if (m_banman) { + addresses.erase(std::remove_if(addresses.begin(), addresses.end(), + [this](const CAddress& addr){return m_banman->IsDiscouraged(addr) || m_banman->IsBanned(addr);}), + addresses.end()); + } + return addresses; +} + +std::vector<CAddress> CConnman::GetAddresses(Network requestor_network, size_t max_addresses, size_t max_pct) +{ + const auto current_time = GetTime<std::chrono::microseconds>(); + if (m_addr_response_caches.find(requestor_network) == m_addr_response_caches.end() || + m_addr_response_caches[requestor_network].m_update_addr_response < current_time) { + m_addr_response_caches[requestor_network].m_addrs_response_cache = GetAddresses(max_addresses, max_pct); + m_addr_response_caches[requestor_network].m_update_addr_response = current_time + std::chrono::hours(21) + GetRandMillis(std::chrono::hours(6)); + } + return m_addr_response_caches[requestor_network].m_addrs_response_cache; } bool CConnman::AddNode(const std::string& strNode) @@ -2546,7 +2579,7 @@ size_t CConnman::GetNodeCount(NumConnections flags) int nNum = 0; for (const auto& pnode : vNodes) { - if (flags & (pnode->fInbound ? CONNECTIONS_IN : CONNECTIONS_OUT)) { + if (flags & (pnode->IsInboundConn() ? CONNECTIONS_IN : CONNECTIONS_OUT)) { nNum++; } } @@ -2624,7 +2657,7 @@ void CConnman::RecordBytesSent(uint64_t bytes) nMaxOutboundTotalBytesSentInCycle = 0; } - // TODO, exclude whitebind peers + // TODO, exclude peers with download permission nMaxOutboundTotalBytesSentInCycle += bytes; } @@ -2730,26 +2763,26 @@ int CConnman::GetBestHeight() const unsigned int CConnman::GetReceiveFloodSize() const { return nReceiveFloodSize; } -CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, int nMyStartingHeightIn, SOCKET hSocketIn, const CAddress& addrIn, uint64_t nKeyedNetGroupIn, uint64_t nLocalHostNonceIn, const CAddress& addrBindIn, const std::string& addrNameIn, bool fInboundIn, bool block_relay_only) +CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, int nMyStartingHeightIn, SOCKET hSocketIn, const CAddress& addrIn, uint64_t nKeyedNetGroupIn, uint64_t nLocalHostNonceIn, const CAddress& addrBindIn, const std::string& addrNameIn, ConnectionType conn_type_in) : nTimeConnected(GetSystemTimeInSeconds()), addr(addrIn), addrBind(addrBindIn), - fInbound(fInboundIn), nKeyedNetGroup(nKeyedNetGroupIn), // Don't relay addr messages to peers that we connect to as block-relay-only // peers (to prevent adversaries from inferring these links from addr // traffic). - m_addr_known{block_relay_only ? nullptr : MakeUnique<CRollingBloomFilter>(5000, 0.001)}, id(idIn), nLocalHostNonce(nLocalHostNonceIn), + m_conn_type(conn_type_in), nLocalServices(nLocalServicesIn), nMyStartingHeight(nMyStartingHeightIn) { hSocket = hSocketIn; addrName = addrNameIn == "" ? addr.ToStringIPPort() : addrNameIn; hashContinue = uint256(); - if (!block_relay_only) { + if (conn_type_in != ConnectionType::BLOCK_RELAY) { m_tx_relay = MakeUnique<TxRelay>(); + m_addr_known = MakeUnique<CRollingBloomFilter>(5000, 0.001); } for (const std::string &msg : getAllNetMessageTypes()) @@ -2779,7 +2812,7 @@ bool CConnman::NodeFullyConnected(const CNode* pnode) void CConnman::PushMessage(CNode* pnode, CSerializedNetMsg&& msg) { size_t nMessageSize = msg.data.size(); - LogPrint(BCLog::NET, "sending %s (%d bytes) peer=%d\n", SanitizeString(msg.command), nMessageSize, pnode->GetId()); + LogPrint(BCLog::NET, "sending %s (%d bytes) peer=%d\n", SanitizeString(msg.m_type), nMessageSize, pnode->GetId()); // make sure we use the appropriate network transport format std::vector<unsigned char> serializedHeader; @@ -2791,8 +2824,8 @@ void CConnman::PushMessage(CNode* pnode, CSerializedNetMsg&& msg) LOCK(pnode->cs_vSend); bool optimisticSend(pnode->vSendMsg.empty()); - //log total amount of bytes per command - pnode->mapSendBytesPerMsgCmd[msg.command] += nTotalSize; + //log total amount of bytes per message type + pnode->mapSendBytesPerMsgCmd[msg.m_type] += nTotalSize; pnode->nSendSize += nTotalSize; if (pnode->nSendSize > nSendBufferMaxSize) @@ -25,8 +25,9 @@ #include <uint256.h> #include <atomic> +#include <cstdint> #include <deque> -#include <stdint.h> +#include <map> #include <thread> #include <memory> #include <condition_variable> @@ -50,8 +51,8 @@ static const bool DEFAULT_WHITELISTFORCERELAY = false; static const int TIMEOUT_INTERVAL = 20 * 60; /** Run the feeler connection loop once every 2 minutes or 120 seconds. **/ static const int FEELER_INTERVAL = 120; -/** The maximum number of new addresses to accumulate before announcing. */ -static const unsigned int MAX_ADDR_TO_SEND = 1000; +/** The maximum number of addresses from our addrman to return in response to a getaddr message. */ +static constexpr size_t MAX_ADDR_TO_SEND = 1000; /** Maximum length of incoming protocol messages (no message over 4 MB is currently acceptable). */ static const unsigned int MAX_PROTOCOL_MESSAGE_LENGTH = 4 * 1000 * 1000; /** Maximum length of the user agent string in `version` message */ @@ -61,7 +62,7 @@ static const int MAX_OUTBOUND_FULL_RELAY_CONNECTIONS = 8; /** Maximum number of addnode outgoing nodes */ static const int MAX_ADDNODE_CONNECTIONS = 8; /** Maximum number of block-relay-only outgoing connections */ -static const int MAX_BLOCKS_ONLY_CONNECTIONS = 2; +static const int MAX_BLOCK_RELAY_ONLY_CONNECTIONS = 2; /** Maximum number of feeler connections */ static const int MAX_FEELER_CONNECTIONS = 1; /** -listen default */ @@ -110,9 +111,20 @@ struct CSerializedNetMsg CSerializedNetMsg& operator=(const CSerializedNetMsg&) = delete; std::vector<unsigned char> data; - std::string command; + std::string m_type; }; +/** Different types of connections to a peer. This enum encapsulates the + * information we have available at the time of opening or accepting the + * connection. Aside from INBOUND, all types are initiated by us. */ +enum class ConnectionType { + INBOUND, /**< peer initiated connections */ + OUTBOUND, /**< full relay connections (blocks, addrs, txns) made automatically. Addresses selected from AddrMan. */ + MANUAL, /**< connections to addresses added via addnode or the connect command line argument */ + FEELER, /**< short lived connections used to test address validity */ + BLOCK_RELAY, /**< only relay blocks to these automatic outbound connections. Addresses selected from AddrMan. */ + ADDR_FETCH, /**< short lived connections used to solicit addrs when starting the node without a populated AddrMan */ +}; class NetEventsInterface; class CConnman @@ -181,7 +193,7 @@ public: } } - CConnman(uint64_t seed0, uint64_t seed1); + CConnman(uint64_t seed0, uint64_t seed1, bool network_active = true); ~CConnman(); bool Start(CScheduler& scheduler, const Options& options); @@ -197,7 +209,7 @@ public: bool GetNetworkActive() const { return fNetworkActive; }; bool GetUseAddrmanOutgoing() const { return m_use_addrman_outgoing; }; void SetNetworkActive(bool active); - void OpenNetworkConnection(const CAddress& addrConnect, bool fCountFailure, CSemaphoreGrant *grantOutbound = nullptr, const char *strDest = nullptr, bool fOneShot = false, bool fFeeler = false, bool manual_connection = false, bool block_relay_only = false); + void OpenNetworkConnection(const CAddress& addrConnect, bool fCountFailure, CSemaphoreGrant *grantOutbound = nullptr, const char *strDest = nullptr, ConnectionType conn_type = ConnectionType::OUTBOUND); bool CheckIncomingNonce(uint64_t nonce); bool ForNode(NodeId id, std::function<bool(CNode* pnode)> func); @@ -247,11 +259,17 @@ public: }; // Addrman functions - size_t GetAddressCount() const; void SetServices(const CService &addr, ServiceFlags nServices); void MarkAddressGood(const CAddress& addr); - void AddNewAddresses(const std::vector<CAddress>& vAddr, const CAddress& addrFrom, int64_t nTimePenalty = 0); - std::vector<CAddress> GetAddresses(); + bool AddNewAddresses(const std::vector<CAddress>& vAddr, const CAddress& addrFrom, int64_t nTimePenalty = 0); + std::vector<CAddress> GetAddresses(size_t max_addresses, size_t max_pct); + /** + * Cache is used to minimize topology leaks, so it should + * be used for all non-trusted calls, for example, p2p. + * A non-malicious call (from RPC or a peer with addr permission) should + * call the function without a parameter to avoid using the cache. + */ + std::vector<CAddress> GetAddresses(Network requestor_network, size_t max_addresses, size_t max_pct); // This allows temporarily exceeding m_max_outbound_full_relay, with the goal of finding // a peer that is better than all our current peers. @@ -341,8 +359,8 @@ private: bool Bind(const CService& addr, unsigned int flags, NetPermissionFlags permissions); bool InitBinds(const std::vector<CService>& binds, const std::vector<NetWhitebindPermissions>& whiteBinds); void ThreadOpenAddedConnections(); - void AddOneShot(const std::string& strDest); - void ProcessOneShot(); + void AddAddrFetch(const std::string& strDest); + void ProcessAddrFetch(); void ThreadOpenConnections(std::vector<std::string> connect); void ThreadMessageHandler(); void AcceptConnection(const ListenSocket& hListenSocket); @@ -363,7 +381,7 @@ private: CNode* FindNode(const CService& addr); bool AttemptToEvictConnection(); - CNode* ConnectNode(CAddress addrConnect, const char *pszDest, bool fCountFailure, bool manual_connection, bool block_relay_only); + CNode* ConnectNode(CAddress addrConnect, const char *pszDest, bool fCountFailure, ConnectionType conn_type); void AddWhitelistPermissionFlags(NetPermissionFlags& flags, const CNetAddr &addr) const; void DeleteNode(CNode* pnode); @@ -406,8 +424,8 @@ private: std::atomic<bool> fNetworkActive{true}; bool fAddressesInitialized{false}; CAddrMan addrman; - std::deque<std::string> vOneShots GUARDED_BY(cs_vOneShots); - RecursiveMutex cs_vOneShots; + std::deque<std::string> m_addr_fetches GUARDED_BY(m_addr_fetches_mutex); + RecursiveMutex m_addr_fetches_mutex; std::vector<std::string> vAddedNodes GUARDED_BY(cs_vAddedNodes); RecursiveMutex cs_vAddedNodes; std::vector<CNode*> vNodes GUARDED_BY(cs_vNodes); @@ -417,6 +435,29 @@ private: unsigned int nPrevNodeCount{0}; /** + * Cache responses to addr requests to minimize privacy leak. + * Attack example: scraping addrs in real-time may allow an attacker + * to infer new connections of the victim by detecting new records + * with fresh timestamps (per self-announcement). + */ + struct CachedAddrResponse { + std::vector<CAddress> m_addrs_response_cache; + std::chrono::microseconds m_update_addr_response{0}; + }; + + /** + * Addr responses stored in different caches + * per network prevent cross-network node identification. + * If a node for example is multi-homed under Tor and IPv6, + * a single cache (or no cache at all) would let an attacker + * to easily detect that it is the same node by comparing responses. + * The used memory equals to 1000 CAddress records (or around 32 bytes) per + * distinct Network (up to 5) we have/had an inbound peer from, + * resulting in at most ~160 KB. + */ + std::map<Network, CachedAddrResponse> m_addr_response_caches; + + /** * Services this instance offers. * * This data is replicated in each CNode instance we create during peer @@ -448,6 +489,7 @@ private: std::atomic<int> nBestHeight; CClientUIInterface* clientInterface; NetEventsInterface* m_msgproc; + /** Pointer to this node's banman. May be nullptr - check existence before dereferencing. */ BanMan* m_banman; /** SipHasher seeds for deterministic randomness */ @@ -482,7 +524,7 @@ void Discover(); void StartMapPort(); void InterruptMapPort(); void StopMapPort(); -unsigned short GetListenPort(); +uint16_t GetListenPort(); struct CombinerAll { @@ -612,13 +654,13 @@ public: */ class CNetMessage { public: - CDataStream m_recv; // received message data - int64_t m_time = 0; // time (in microseconds) of message receipt. + CDataStream m_recv; //!< received message data + std::chrono::microseconds m_time{0}; //!< time of message receipt bool m_valid_netmagic = false; bool m_valid_header = false; bool m_valid_checksum = false; - uint32_t m_message_size = 0; // size of the payload - uint32_t m_raw_message_size = 0; // used wire size of the message (including header/checksum) + uint32_t m_message_size{0}; //!< size of the payload + uint32_t m_raw_message_size{0}; //!< used wire size of the message (including header/checksum) std::string m_command; CNetMessage(CDataStream&& recv_in) : m_recv(std::move(recv_in)) {} @@ -642,7 +684,7 @@ public: // read and deserialize data virtual int Read(const char *data, unsigned int bytes) = 0; // decomposes a message from the context - virtual CNetMessage GetMessage(const CMessageHeader::MessageStartChars& message_start, int64_t time) = 0; + virtual CNetMessage GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time) = 0; virtual ~TransportDeserializer() {} }; @@ -695,7 +737,7 @@ public: if (ret < 0) Reset(); return ret; } - CNetMessage GetMessage(const CMessageHeader::MessageStartChars& message_start, int64_t time) override; + CNetMessage GetMessage(const CMessageHeader::MessageStartChars& message_start, std::chrono::microseconds time) override; }; /** The TransportSerializer prepares messages for the network transport @@ -764,12 +806,8 @@ public: } // This boolean is unusued in actual processing, only present for backward compatibility at RPC/QT level bool m_legacyWhitelisted{false}; - bool fFeeler{false}; // If true this node is being used as a short lived feeler. - bool fOneShot{false}; - bool m_manual_connection{false}; bool fClient{false}; // set by version message bool m_limited_node{false}; //after BIP159, set by version message - const bool fInbound; std::atomic_bool fSuccessfullyConnected{false}; // Setting fDisconnect to true will cause the node to be disconnected the // next time DisconnectNodes() runs @@ -782,6 +820,60 @@ public: std::atomic_bool fPauseRecv{false}; std::atomic_bool fPauseSend{false}; + bool IsOutboundOrBlockRelayConn() const { + switch(m_conn_type) { + case ConnectionType::OUTBOUND: + case ConnectionType::BLOCK_RELAY: + return true; + case ConnectionType::INBOUND: + case ConnectionType::MANUAL: + case ConnectionType::ADDR_FETCH: + case ConnectionType::FEELER: + return false; + } + + assert(false); + } + + bool IsFullOutboundConn() const { + return m_conn_type == ConnectionType::OUTBOUND; + } + + bool IsManualConn() const { + return m_conn_type == ConnectionType::MANUAL; + } + + bool IsBlockOnlyConn() const { + return m_conn_type == ConnectionType::BLOCK_RELAY; + } + + bool IsFeelerConn() const { + return m_conn_type == ConnectionType::FEELER; + } + + bool IsAddrFetchConn() const { + return m_conn_type == ConnectionType::ADDR_FETCH; + } + + bool IsInboundConn() const { + return m_conn_type == ConnectionType::INBOUND; + } + + bool ExpectServicesFromConn() const { + switch(m_conn_type) { + case ConnectionType::INBOUND: + case ConnectionType::MANUAL: + case ConnectionType::FEELER: + return false; + case ConnectionType::OUTBOUND: + case ConnectionType::BLOCK_RELAY: + case ConnectionType::ADDR_FETCH: + return true; + } + + assert(false); + } + protected: mapMsgCmdSize mapSendBytesPerMsgCmd; mapMsgCmdSize mapRecvBytesPerMsgCmd GUARDED_BY(cs_vRecv); @@ -792,7 +884,7 @@ public: // flood relay std::vector<CAddress> vAddrToSend; - const std::unique_ptr<CRollingBloomFilter> m_addr_known; + std::unique_ptr<CRollingBloomFilter> m_addr_known = nullptr; bool fGetAddr{false}; std::chrono::microseconds m_next_addr_send GUARDED_BY(cs_sendProcessing){0}; std::chrono::microseconds m_next_local_addr_send GUARDED_BY(cs_sendProcessing){0}; @@ -803,7 +895,7 @@ public: // There is no final sorting before sending, as they are always sent immediately // and in the order requested. std::vector<uint256> vInventoryBlockToSend GUARDED_BY(cs_inventory); - RecursiveMutex cs_inventory; + Mutex cs_inventory; struct TxRelay { mutable RecursiveMutex cs_filter; @@ -845,8 +937,8 @@ public: // Ping time measurement: // The pong reply we're expecting, or 0 if no pong expected. std::atomic<uint64_t> nPingNonceSent{0}; - // Time (in usec) the last ping was sent, or 0 if no ping was ever sent. - std::atomic<int64_t> nPingUsecStart{0}; + /** When the last ping was sent, or 0 if no ping was ever sent */ + std::atomic<std::chrono::microseconds> m_ping_start{std::chrono::microseconds{0}}; // Last measured round-trip time. std::atomic<int64_t> nPingUsecTime{0}; // Best measured round-trip time. @@ -856,7 +948,7 @@ public: std::set<uint256> orphan_work_set; - CNode(NodeId id, ServiceFlags nLocalServicesIn, int nMyStartingHeightIn, SOCKET hSocketIn, const CAddress &addrIn, uint64_t nKeyedNetGroupIn, uint64_t nLocalHostNonceIn, const CAddress &addrBindIn, const std::string &addrNameIn = "", bool fInboundIn = false, bool block_relay_only = false); + CNode(NodeId id, ServiceFlags nLocalServicesIn, int nMyStartingHeightIn, SOCKET hSocketIn, const CAddress &addrIn, uint64_t nKeyedNetGroupIn, uint64_t nLocalHostNonceIn, const CAddress &addrBindIn, const std::string &addrNameIn, ConnectionType conn_type_in); ~CNode(); CNode(const CNode&) = delete; CNode& operator=(const CNode&) = delete; @@ -864,6 +956,7 @@ public: private: const NodeId id; const uint64_t nLocalHostNonce; + const ConnectionType m_conn_type; //! Services offered to this peer. //! @@ -965,33 +1058,23 @@ public: } - void AddInventoryKnown(const CInv& inv) + void AddKnownTx(const uint256& hash) { if (m_tx_relay != nullptr) { LOCK(m_tx_relay->cs_tx_inventory); - m_tx_relay->filterInventoryKnown.insert(inv.hash); + m_tx_relay->filterInventoryKnown.insert(hash); } } - void PushInventory(const CInv& inv) + void PushTxInventory(const uint256& hash) { - if (inv.type == MSG_TX && m_tx_relay != nullptr) { - LOCK(m_tx_relay->cs_tx_inventory); - if (!m_tx_relay->filterInventoryKnown.contains(inv.hash)) { - m_tx_relay->setInventoryTxToSend.insert(inv.hash); - } - } else if (inv.type == MSG_BLOCK) { - LOCK(cs_inventory); - vInventoryBlockToSend.push_back(inv.hash); + if (m_tx_relay == nullptr) return; + LOCK(m_tx_relay->cs_tx_inventory); + if (!m_tx_relay->filterInventoryKnown.contains(hash)) { + m_tx_relay->setInventoryTxToSend.insert(hash); } } - void PushBlockHash(const uint256 &hash) - { - LOCK(cs_inventory); - vBlockHashesToAnnounce.push_back(hash); - } - void CloseSocketDisconnect(); void copyStats(CNodeStats &stats, const std::vector<bool> &m_asmap); diff --git a/src/net_permissions.cpp b/src/net_permissions.cpp index 22fa5ee73b..53648deb40 100644 --- a/src/net_permissions.cpp +++ b/src/net_permissions.cpp @@ -8,8 +8,20 @@ #include <util/system.h> #include <util/translation.h> +const std::vector<std::string> NET_PERMISSIONS_DOC{ + "bloomfilter (allow requesting BIP37 filtered blocks and transactions)", + "noban (do not ban for misbehavior; implies download)", + "forcerelay (relay transactions that are already in the mempool; implies relay)", + "relay (relay even in -blocksonly mode)", + "mempool (allow requesting BIP35 mempool contents)", + "download (allow getheaders during IBD, no disconnect after maxuploadtarget limit)", + "addr (responses to GETADDR avoid hitting the cache and contain random records with the most up-to-date info)" +}; + +namespace { + // The parse the following format "perm1,perm2@xxxxxx" -bool TryParsePermissionFlags(const std::string str, NetPermissionFlags& output, size_t& readen, std::string& error) +bool TryParsePermissionFlags(const std::string str, NetPermissionFlags& output, size_t& readen, bilingual_str& error) { NetPermissionFlags flags = PF_NONE; const auto atSeparator = str.find('@'); @@ -36,11 +48,13 @@ bool TryParsePermissionFlags(const std::string str, NetPermissionFlags& output, else if (permission == "noban") NetPermissions::AddFlag(flags, PF_NOBAN); else if (permission == "forcerelay") NetPermissions::AddFlag(flags, PF_FORCERELAY); else if (permission == "mempool") NetPermissions::AddFlag(flags, PF_MEMPOOL); + else if (permission == "download") NetPermissions::AddFlag(flags, PF_DOWNLOAD); else if (permission == "all") NetPermissions::AddFlag(flags, PF_ALL); else if (permission == "relay") NetPermissions::AddFlag(flags, PF_RELAY); + else if (permission == "addr") NetPermissions::AddFlag(flags, PF_ADDR); else if (permission.length() == 0); // Allow empty entries else { - error = strprintf(_("Invalid P2P permission: '%s'").translated, permission); + error = strprintf(_("Invalid P2P permission: '%s'"), permission); return false; } } @@ -48,10 +62,12 @@ bool TryParsePermissionFlags(const std::string str, NetPermissionFlags& output, } output = flags; - error = ""; + error = Untranslated(""); return true; } +} + std::vector<std::string> NetPermissions::ToStrings(NetPermissionFlags flags) { std::vector<std::string> strings; @@ -60,10 +76,12 @@ std::vector<std::string> NetPermissions::ToStrings(NetPermissionFlags flags) if (NetPermissions::HasFlag(flags, PF_FORCERELAY)) strings.push_back("forcerelay"); if (NetPermissions::HasFlag(flags, PF_RELAY)) strings.push_back("relay"); if (NetPermissions::HasFlag(flags, PF_MEMPOOL)) strings.push_back("mempool"); + if (NetPermissions::HasFlag(flags, PF_DOWNLOAD)) strings.push_back("download"); + if (NetPermissions::HasFlag(flags, PF_ADDR)) strings.push_back("addr"); return strings; } -bool NetWhitebindPermissions::TryParse(const std::string str, NetWhitebindPermissions& output, std::string& error) +bool NetWhitebindPermissions::TryParse(const std::string str, NetWhitebindPermissions& output, bilingual_str& error) { NetPermissionFlags flags; size_t offset; @@ -76,17 +94,17 @@ bool NetWhitebindPermissions::TryParse(const std::string str, NetWhitebindPermis return false; } if (addrBind.GetPort() == 0) { - error = strprintf(_("Need to specify a port with -whitebind: '%s'").translated, strBind); + error = strprintf(_("Need to specify a port with -whitebind: '%s'"), strBind); return false; } output.m_flags = flags; output.m_service = addrBind; - error = ""; + error = Untranslated(""); return true; } -bool NetWhitelistPermissions::TryParse(const std::string str, NetWhitelistPermissions& output, std::string& error) +bool NetWhitelistPermissions::TryParse(const std::string str, NetWhitelistPermissions& output, bilingual_str& error) { NetPermissionFlags flags; size_t offset; @@ -96,12 +114,12 @@ bool NetWhitelistPermissions::TryParse(const std::string str, NetWhitelistPermis CSubNet subnet; LookupSubNet(net, subnet); if (!subnet.IsValid()) { - error = strprintf(_("Invalid netmask specified in -whitelist: '%s'").translated, net); + error = strprintf(_("Invalid netmask specified in -whitelist: '%s'"), net); return false; } output.m_flags = flags; output.m_subnet = subnet; - error = ""; + error = Untranslated(""); return true; } diff --git a/src/net_permissions.h b/src/net_permissions.h index 962a2159fc..5b68f635a7 100644 --- a/src/net_permissions.h +++ b/src/net_permissions.h @@ -2,14 +2,19 @@ // Distributed under the MIT software license, see the accompanying // file COPYING or http://www.opensource.org/licenses/mit-license.php. +#include <netaddress.h> + #include <string> #include <vector> -#include <netaddress.h> #ifndef BITCOIN_NET_PERMISSIONS_H #define BITCOIN_NET_PERMISSIONS_H -enum NetPermissionFlags -{ + +struct bilingual_str; + +extern const std::vector<std::string> NET_PERMISSIONS_DOC; + +enum NetPermissionFlags { PF_NONE = 0, // Can query bloomfilter even if -peerbloomfilters is false PF_BLOOMFILTER = (1U << 1), @@ -18,15 +23,20 @@ enum NetPermissionFlags // Always relay transactions from this peer, even if already in mempool // Keep parameter interaction: forcerelay implies relay PF_FORCERELAY = (1U << 2) | PF_RELAY, - // Can't be banned for misbehavior - PF_NOBAN = (1U << 4), + // Allow getheaders during IBD and block-download after maxuploadtarget limit + PF_DOWNLOAD = (1U << 6), + // Can't be banned/disconnected/discouraged for misbehavior + PF_NOBAN = (1U << 4) | PF_DOWNLOAD, // Can query the mempool PF_MEMPOOL = (1U << 5), + // Can request addrs without hitting a privacy-preserving cache + PF_ADDR = (1U << 7), // True if the user did not specifically set fine grained permissions PF_ISIMPLICIT = (1U << 31), - PF_ALL = PF_BLOOMFILTER | PF_FORCERELAY | PF_RELAY | PF_NOBAN | PF_MEMPOOL, + PF_ALL = PF_BLOOMFILTER | PF_FORCERELAY | PF_RELAY | PF_NOBAN | PF_MEMPOOL | PF_DOWNLOAD | PF_ADDR, }; + class NetPermissions { public: @@ -45,17 +55,18 @@ public: flags = static_cast<NetPermissionFlags>(flags & ~f); } }; + class NetWhitebindPermissions : public NetPermissions { public: - static bool TryParse(const std::string str, NetWhitebindPermissions& output, std::string& error); + static bool TryParse(const std::string str, NetWhitebindPermissions& output, bilingual_str& error); CService m_service; }; class NetWhitelistPermissions : public NetPermissions { public: - static bool TryParse(const std::string str, NetWhitelistPermissions& output, std::string& error); + static bool TryParse(const std::string str, NetWhitelistPermissions& output, bilingual_str& error); CSubNet m_subnet; }; diff --git a/src/net_processing.cpp b/src/net_processing.cpp index 587b655d1c..7b83583e41 100644 --- a/src/net_processing.cpp +++ b/src/net_processing.cpp @@ -13,10 +13,9 @@ #include <consensus/validation.h> #include <hash.h> #include <index/blockfilterindex.h> -#include <validation.h> #include <merkleblock.h> -#include <netmessagemaker.h> #include <netbase.h> +#include <netmessagemaker.h> #include <policy/fees.h> #include <policy/policy.h> #include <primitives/block.h> @@ -26,22 +25,22 @@ #include <scheduler.h> #include <tinyformat.h> #include <txmempool.h> -#include <util/system.h> +#include <util/check.h> // For NDEBUG compile time check #include <util/strencodings.h> +#include <util/system.h> +#include <validation.h> #include <memory> #include <typeinfo> -#if defined(NDEBUG) -# error "Bitcoin cannot be compiled without assertions." -#endif - /** Expiration time for orphan transactions in seconds */ static constexpr int64_t ORPHAN_TX_EXPIRE_TIME = 20 * 60; /** Minimum time between orphan transactions expire time checks in seconds */ static constexpr int64_t ORPHAN_TX_EXPIRE_INTERVAL = 5 * 60; /** How long to cache transactions in mapRelay for normal relay */ -static constexpr std::chrono::seconds RELAY_TX_CACHE_TIME{15 * 60}; +static constexpr std::chrono::seconds RELAY_TX_CACHE_TIME = std::chrono::minutes{15}; +/** How long a transaction has to be in the mempool before it can unconditionally be relayed (even when not in mapRelay). */ +static constexpr std::chrono::seconds UNCONDITIONAL_RELAY_DELAY = std::chrono::minutes{2}; /** Headers download timeout expressed in microseconds * Timeout = base + per_header * (expected number of headers) */ static constexpr int64_t HEADERS_DOWNLOAD_TIMEOUT_BASE = 15 * 60 * 1000000; // 15 minutes @@ -66,8 +65,8 @@ static constexpr int STALE_RELAY_AGE_LIMIT = 30 * 24 * 60 * 60; /// Age after which a block is considered historical for purposes of rate /// limiting block relay. Set to one week, denominated in seconds. static constexpr int HISTORICAL_BLOCK_AGE = 7 * 24 * 60 * 60; -/** Time between pings automatically sent out for latency probing and keepalive (in seconds). */ -static const int PING_INTERVAL = 2 * 60; +/** Time between pings automatically sent out for latency probing and keepalive */ +static constexpr std::chrono::minutes PING_INTERVAL{2}; /** The maximum number of entries in a locator */ static const unsigned int MAX_LOCATOR_SZ = 101; /** The maximum number of entries in an 'inv' protocol message */ @@ -76,6 +75,8 @@ static const unsigned int MAX_INV_SZ = 50000; static constexpr int32_t MAX_PEER_TX_IN_FLIGHT = 100; /** Maximum number of announced transactions from a peer */ static constexpr int32_t MAX_PEER_TX_ANNOUNCEMENTS = 2 * MAX_INV_SZ; +/** How many microseconds to delay requesting transactions via txids, if we have wtxid-relaying peers */ +static constexpr std::chrono::microseconds TXID_RELAY_DELAY{std::chrono::seconds{2}}; /** How many microseconds to delay requesting transactions from inbound peers */ static constexpr std::chrono::microseconds INBOUND_PEER_TX_DELAY{std::chrono::seconds{2}}; /** How long to wait (in microseconds) before downloading a transaction from an additional peer */ @@ -120,11 +121,20 @@ static constexpr std::chrono::hours AVG_LOCAL_ADDRESS_BROADCAST_INTERVAL{24}; /** Average delay between peer address broadcasts */ static constexpr std::chrono::seconds AVG_ADDRESS_BROADCAST_INTERVAL{30}; /** Average delay between trickled inventory transmissions in seconds. - * Blocks and whitelisted receivers bypass this, outbound peers get half this delay. */ + * Blocks and peers with noban permission bypass this, outbound peers get half this delay. */ static const unsigned int INVENTORY_BROADCAST_INTERVAL = 5; -/** Maximum number of inventory items to send per transmission. +/** Maximum rate of inventory items to send per second. * Limits the impact of low-fee transaction floods. */ -static constexpr unsigned int INVENTORY_BROADCAST_MAX = 7 * INVENTORY_BROADCAST_INTERVAL; +static constexpr unsigned int INVENTORY_BROADCAST_PER_SECOND = 7; +/** Maximum number of inventory items to send per transmission. */ +static constexpr unsigned int INVENTORY_BROADCAST_MAX = INVENTORY_BROADCAST_PER_SECOND * INVENTORY_BROADCAST_INTERVAL; +/** The number of most recently announced transactions a peer can request. */ +static constexpr unsigned int INVENTORY_MAX_RECENT_RELAY = 3500; +/** Verify that INVENTORY_MAX_RECENT_RELAY is enough to cache everything typically + * relayed before unconditional relay from the mempool kicks in. This is only a + * lower bound, and it should be larger to account for higher inv rate to outbound + * peers, and random variations in the broadcast mechanism. */ +static_assert(INVENTORY_MAX_RECENT_RELAY >= INVENTORY_BROADCAST_PER_SECOND * UNCONDITIONAL_RELAY_DELAY / std::chrono::seconds{1}, "INVENTORY_RELAY_MAX too low"); /** Average delay between feefilter broadcasts in seconds. */ static constexpr unsigned int AVG_FEEFILTER_BROADCAST_INTERVAL = 10 * 60; /** Maximum feefilter broadcast delay after significant change. */ @@ -133,6 +143,8 @@ static constexpr unsigned int MAX_FEEFILTER_CHANGE_DELAY = 5 * 60; static constexpr uint32_t MAX_GETCFILTERS_SIZE = 1000; /** Maximum number of cf hashes that may be requested with one getcfheaders. See BIP 157. */ static constexpr uint32_t MAX_GETCFHEADERS_SIZE = 2000; +/** the maximum percentage of addresses from our addrman to return in response to a getaddr message. */ +static constexpr size_t MAX_PCT_ADDR_TO_SEND = 23; struct COrphanTx { // When modifying, adapt the copy of this definition in tests/DoS_tests. @@ -143,12 +155,10 @@ struct COrphanTx { }; RecursiveMutex g_cs_orphans; std::map<uint256, COrphanTx> mapOrphanTransactions GUARDED_BY(g_cs_orphans); +std::map<uint256, std::map<uint256, COrphanTx>::iterator> g_orphans_by_wtxid GUARDED_BY(g_cs_orphans); void EraseOrphansFor(NodeId peer); -/** Increase a node's misbehavior score. */ -void Misbehaving(NodeId nodeid, int howmuch, const std::string& message="") EXCLUSIVE_LOCKS_REQUIRED(cs_main); - // Internal stuff namespace { /** Number of nodes with fSyncStarted. */ @@ -179,6 +189,21 @@ namespace { * million to make it highly unlikely for users to have issues with this * filter. * + * We typically only add wtxids to this filter. For non-segwit + * transactions, the txid == wtxid, so this only prevents us from + * re-downloading non-segwit transactions when communicating with + * non-wtxidrelay peers -- which is important for avoiding malleation + * attacks that could otherwise interfere with transaction relay from + * non-wtxidrelay peers. For communicating with wtxidrelay peers, having + * the reject filter store wtxids is exactly what we want to avoid + * redownload of a rejected transaction. + * + * In cases where we can tell that a segwit transaction will fail + * validation no matter the witness, we may add the txid of such + * transaction to the filter as well. This can be helpful when + * communicating with txid-relay peers or if we were to otherwise fetch a + * transaction via txid (eg in our orphan handling). + * * Memory used: 1.3 MB */ std::unique_ptr<CRollingBloomFilter> recentRejects GUARDED_BY(cs_main); @@ -189,7 +214,7 @@ namespace { * We use this to avoid requesting transactions that have already been * confirnmed. */ - RecursiveMutex g_cs_recent_confirmed_transactions; + Mutex g_cs_recent_confirmed_transactions; std::unique_ptr<CRollingBloomFilter> g_recent_confirmed_transactions GUARDED_BY(g_cs_recent_confirmed_transactions); /** Blocks that are in flight, and that are in the queue to be downloaded. */ @@ -210,13 +235,16 @@ namespace { /** Number of peers from which we're downloading blocks. */ int nPeersWithValidatedDownloads GUARDED_BY(cs_main) = 0; + /** Number of peers with wtxid relay. */ + int g_wtxid_relay_peers GUARDED_BY(cs_main) = 0; + /** Number of outbound peers with m_chain_sync.m_protect. */ int g_outbound_peers_with_protect_from_disconnect GUARDED_BY(cs_main) = 0; /** When our tip was last updated. */ std::atomic<int64_t> g_last_tip_update(0); - /** Relay map */ + /** Relay map (txid or wtxid -> CTransactionRef) */ typedef std::map<uint256, CTransactionRef> MapRelay; MapRelay mapRelay GUARDED_BY(cs_main); /** Expiration-time ordered list of (expire time, relay map entry) pairs. */ @@ -252,8 +280,8 @@ struct CNodeState { bool fCurrentlyConnected; //! Accumulated misbehaviour score for this peer. int nMisbehavior; - //! Whether this peer should be disconnected and banned (unless whitelisted). - bool fShouldBan; + //! Whether this peer should be disconnected and marked as discouraged (unless it has the noban permission). + bool m_should_discourage; //! String name of this peer (debugging/logging purposes). const std::string name; //! The best known block we know this peer has announced. @@ -378,7 +406,7 @@ struct CNodeState { /* Track when to attempt download of announced transactions (process * time in micros -> txid) */ - std::multimap<std::chrono::microseconds, uint256> m_tx_process_time; + std::multimap<std::chrono::microseconds, GenTxid> m_tx_process_time; //! Store all the transactions a peer has recently announced std::set<uint256> m_tx_announced; @@ -398,13 +426,19 @@ struct CNodeState { //! Whether this peer is a manual connection bool m_is_manual_connection; + //! A rolling bloom filter of all announced tx CInvs to this peer. + CRollingBloomFilter m_recently_announced_invs = CRollingBloomFilter{INVENTORY_MAX_RECENT_RELAY, 0.000001}; + + //! Whether this peer relays txs via wtxid + bool m_wtxid_relay{false}; + CNodeState(CAddress addrIn, std::string addrNameIn, bool is_inbound, bool is_manual) : address(addrIn), name(std::move(addrNameIn)), m_is_inbound(is_inbound), m_is_manual_connection (is_manual) { fCurrentlyConnected = false; nMisbehavior = 0; - fShouldBan = false; + m_should_discourage = false; pindexBestKnownBlock = nullptr; hashLastUnknownBlock.SetNull(); pindexLastCommonBlock = nullptr; @@ -425,6 +459,7 @@ struct CNodeState { fSupportsDesiredCmpctVersion = false; m_chain_sync = { 0, nullptr, false, false }; m_last_block_announcement = 0; + m_recently_announced_invs.reset(); } }; @@ -441,32 +476,32 @@ static CNodeState *State(NodeId pnode) EXCLUSIVE_LOCKS_REQUIRED(cs_main) { return &it->second; } -static void UpdatePreferredDownload(CNode* node, CNodeState* state) EXCLUSIVE_LOCKS_REQUIRED(cs_main) +static void UpdatePreferredDownload(const CNode& node, CNodeState* state) EXCLUSIVE_LOCKS_REQUIRED(cs_main) { nPreferredDownload -= state->fPreferredDownload; // Whether this node should be marked as a preferred download node. - state->fPreferredDownload = (!node->fInbound || node->HasPermission(PF_NOBAN)) && !node->fOneShot && !node->fClient; + state->fPreferredDownload = (!node.IsInboundConn() || node.HasPermission(PF_NOBAN)) && !node.IsAddrFetchConn() && !node.fClient; nPreferredDownload += state->fPreferredDownload; } -static void PushNodeVersion(CNode *pnode, CConnman* connman, int64_t nTime) +static void PushNodeVersion(CNode& pnode, CConnman& connman, int64_t nTime) { // Note that pnode->GetLocalServices() is a reflection of the local // services we were offering when the CNode object was created for this // peer. - ServiceFlags nLocalNodeServices = pnode->GetLocalServices(); - uint64_t nonce = pnode->GetLocalNonce(); - int nNodeStartingHeight = pnode->GetMyStartingHeight(); - NodeId nodeid = pnode->GetId(); - CAddress addr = pnode->addr; + ServiceFlags nLocalNodeServices = pnode.GetLocalServices(); + uint64_t nonce = pnode.GetLocalNonce(); + int nNodeStartingHeight = pnode.GetMyStartingHeight(); + NodeId nodeid = pnode.GetId(); + CAddress addr = pnode.addr; CAddress addrYou = (addr.IsRoutable() && !IsProxy(addr) ? addr : CAddress(CService(), addr.nServices)); CAddress addrMe = CAddress(CService(), nLocalNodeServices); - connman->PushMessage(pnode, CNetMsgMaker(INIT_PROTO_VERSION).Make(NetMsgType::VERSION, PROTOCOL_VERSION, (uint64_t)nLocalNodeServices, nTime, addrYou, addrMe, - nonce, strSubVersion, nNodeStartingHeight, ::g_relay_txes && pnode->m_tx_relay != nullptr)); + connman.PushMessage(&pnode, CNetMsgMaker(INIT_PROTO_VERSION).Make(NetMsgType::VERSION, PROTOCOL_VERSION, (uint64_t)nLocalNodeServices, nTime, addrYou, addrMe, + nonce, strSubVersion, nNodeStartingHeight, ::g_relay_txes && pnode.m_tx_relay != nullptr)); if (fLogIPs) { LogPrint(BCLog::NET, "send version message: version %d, blocks=%d, us=%s, them=%s, peer=%d\n", PROTOCOL_VERSION, nNodeStartingHeight, addrMe.ToString(), addrYou.ToString(), nodeid); @@ -576,7 +611,7 @@ static void UpdateBlockAvailability(NodeId nodeid, const uint256 &hash) EXCLUSIV * lNodesAnnouncingHeaderAndIDs, and keeping that list under a certain size by * removing the first element if necessary. */ -static void MaybeSetPeerAsAnnouncingHeaderAndIDs(NodeId nodeid, CConnman* connman) EXCLUSIVE_LOCKS_REQUIRED(cs_main) +static void MaybeSetPeerAsAnnouncingHeaderAndIDs(NodeId nodeid, CConnman& connman) EXCLUSIVE_LOCKS_REQUIRED(cs_main) { AssertLockHeld(cs_main); CNodeState* nodestate = State(nodeid); @@ -592,20 +627,20 @@ static void MaybeSetPeerAsAnnouncingHeaderAndIDs(NodeId nodeid, CConnman* connma return; } } - connman->ForNode(nodeid, [connman](CNode* pfrom){ + connman.ForNode(nodeid, [&connman](CNode* pfrom){ AssertLockHeld(cs_main); uint64_t nCMPCTBLOCKVersion = (pfrom->GetLocalServices() & NODE_WITNESS) ? 2 : 1; if (lNodesAnnouncingHeaderAndIDs.size() >= 3) { // As per BIP152, we only get 3 of our peers to announce // blocks using compact encodings. - connman->ForNode(lNodesAnnouncingHeaderAndIDs.front(), [connman, nCMPCTBLOCKVersion](CNode* pnodeStop){ + connman.ForNode(lNodesAnnouncingHeaderAndIDs.front(), [&connman, nCMPCTBLOCKVersion](CNode* pnodeStop){ AssertLockHeld(cs_main); - connman->PushMessage(pnodeStop, CNetMsgMaker(pnodeStop->GetSendVersion()).Make(NetMsgType::SENDCMPCT, /*fAnnounceUsingCMPCTBLOCK=*/false, nCMPCTBLOCKVersion)); + connman.PushMessage(pnodeStop, CNetMsgMaker(pnodeStop->GetSendVersion()).Make(NetMsgType::SENDCMPCT, /*fAnnounceUsingCMPCTBLOCK=*/false, nCMPCTBLOCKVersion)); return true; }); lNodesAnnouncingHeaderAndIDs.pop_front(); } - connman->PushMessage(pfrom, CNetMsgMaker(pfrom->GetSendVersion()).Make(NetMsgType::SENDCMPCT, /*fAnnounceUsingCMPCTBLOCK=*/true, nCMPCTBLOCKVersion)); + connman.PushMessage(pfrom, CNetMsgMaker(pfrom->GetSendVersion()).Make(NetMsgType::SENDCMPCT, /*fAnnounceUsingCMPCTBLOCK=*/true, nCMPCTBLOCKVersion)); lNodesAnnouncingHeaderAndIDs.push_back(pfrom->GetId()); return true; }); @@ -724,34 +759,34 @@ static void FindNextBlocksToDownload(NodeId nodeid, unsigned int count, std::vec } } -void EraseTxRequest(const uint256& txid) EXCLUSIVE_LOCKS_REQUIRED(cs_main) +void EraseTxRequest(const GenTxid& gtxid) EXCLUSIVE_LOCKS_REQUIRED(cs_main) { - g_already_asked_for.erase(txid); + g_already_asked_for.erase(gtxid.GetHash()); } -std::chrono::microseconds GetTxRequestTime(const uint256& txid) EXCLUSIVE_LOCKS_REQUIRED(cs_main) +std::chrono::microseconds GetTxRequestTime(const GenTxid& gtxid) EXCLUSIVE_LOCKS_REQUIRED(cs_main) { - auto it = g_already_asked_for.find(txid); + auto it = g_already_asked_for.find(gtxid.GetHash()); if (it != g_already_asked_for.end()) { return it->second; } return {}; } -void UpdateTxRequestTime(const uint256& txid, std::chrono::microseconds request_time) EXCLUSIVE_LOCKS_REQUIRED(cs_main) +void UpdateTxRequestTime(const GenTxid& gtxid, std::chrono::microseconds request_time) EXCLUSIVE_LOCKS_REQUIRED(cs_main) { - auto it = g_already_asked_for.find(txid); + auto it = g_already_asked_for.find(gtxid.GetHash()); if (it == g_already_asked_for.end()) { - g_already_asked_for.insert(std::make_pair(txid, request_time)); + g_already_asked_for.insert(std::make_pair(gtxid.GetHash(), request_time)); } else { g_already_asked_for.update(it, request_time); } } -std::chrono::microseconds CalculateTxGetDataTime(const uint256& txid, std::chrono::microseconds current_time, bool use_inbound_delay) EXCLUSIVE_LOCKS_REQUIRED(cs_main) +std::chrono::microseconds CalculateTxGetDataTime(const GenTxid& gtxid, std::chrono::microseconds current_time, bool use_inbound_delay, bool use_txid_delay) EXCLUSIVE_LOCKS_REQUIRED(cs_main) { std::chrono::microseconds process_time; - const auto last_request_time = GetTxRequestTime(txid); + const auto last_request_time = GetTxRequestTime(gtxid); // First time requesting this tx if (last_request_time.count() == 0) { process_time = current_time; @@ -764,26 +799,29 @@ std::chrono::microseconds CalculateTxGetDataTime(const uint256& txid, std::chron // We delay processing announcements from inbound peers if (use_inbound_delay) process_time += INBOUND_PEER_TX_DELAY; + // We delay processing announcements from peers that use txid-relay (instead of wtxid) + if (use_txid_delay) process_time += TXID_RELAY_DELAY; + return process_time; } -void RequestTx(CNodeState* state, const uint256& txid, std::chrono::microseconds current_time) EXCLUSIVE_LOCKS_REQUIRED(cs_main) +void RequestTx(CNodeState* state, const GenTxid& gtxid, std::chrono::microseconds current_time) EXCLUSIVE_LOCKS_REQUIRED(cs_main) { CNodeState::TxDownloadState& peer_download_state = state->m_tx_download; if (peer_download_state.m_tx_announced.size() >= MAX_PEER_TX_ANNOUNCEMENTS || peer_download_state.m_tx_process_time.size() >= MAX_PEER_TX_ANNOUNCEMENTS || - peer_download_state.m_tx_announced.count(txid)) { + peer_download_state.m_tx_announced.count(gtxid.GetHash())) { // Too many queued announcements from this peer, or we already have // this announcement return; } - peer_download_state.m_tx_announced.insert(txid); + peer_download_state.m_tx_announced.insert(gtxid.GetHash()); // Calculate the time to try requesting this transaction. Use // fPreferredDownload as a proxy for outbound peers. - const auto process_time = CalculateTxGetDataTime(txid, current_time, !state->fPreferredDownload); + const auto process_time = CalculateTxGetDataTime(gtxid, current_time, !state->fPreferredDownload, !state->m_wtxid_relay && g_wtxid_relay_peers > 0); - peer_download_state.m_tx_process_time.emplace(process_time, txid); + peer_download_state.m_tx_process_time.emplace(process_time, gtxid); } } // namespace @@ -797,35 +835,29 @@ void UpdateLastBlockAnnounceTime(NodeId node, int64_t time_in_seconds) if (state) state->m_last_block_announcement = time_in_seconds; } -// Returns true for outbound peers, excluding manual connections, feelers, and -// one-shots. -static bool IsOutboundDisconnectionCandidate(const CNode *node) -{ - return !(node->fInbound || node->m_manual_connection || node->fFeeler || node->fOneShot); -} - void PeerLogicValidation::InitializeNode(CNode *pnode) { CAddress addr = pnode->addr; std::string addrName = pnode->GetAddrName(); NodeId nodeid = pnode->GetId(); { LOCK(cs_main); - mapNodeState.emplace_hint(mapNodeState.end(), std::piecewise_construct, std::forward_as_tuple(nodeid), std::forward_as_tuple(addr, std::move(addrName), pnode->fInbound, pnode->m_manual_connection)); + mapNodeState.emplace_hint(mapNodeState.end(), std::piecewise_construct, std::forward_as_tuple(nodeid), std::forward_as_tuple(addr, std::move(addrName), pnode->IsInboundConn(), pnode->IsManualConn())); } - if(!pnode->fInbound) - PushNodeVersion(pnode, connman, GetTime()); + if(!pnode->IsInboundConn()) + PushNodeVersion(*pnode, *connman, GetTime()); } void PeerLogicValidation::ReattemptInitialBroadcast(CScheduler& scheduler) const { - std::set<uint256> unbroadcast_txids = m_mempool.GetUnbroadcastTxs(); + std::map<uint256, uint256> unbroadcast_txids = m_mempool.GetUnbroadcastTxs(); - for (const uint256& txid : unbroadcast_txids) { + for (const auto& elem : unbroadcast_txids) { // Sanity check: all unbroadcast txns should exist in the mempool - if (m_mempool.exists(txid)) { - RelayTransaction(txid, *connman); + if (m_mempool.exists(elem.first)) { + LOCK(cs_main); + RelayTransaction(elem.first, elem.second, *connman); } else { - m_mempool.RemoveUnbroadcastTx(txid, true); + m_mempool.RemoveUnbroadcastTx(elem.first, true); } } @@ -857,6 +889,8 @@ void PeerLogicValidation::FinalizeNode(NodeId nodeid, bool& fUpdateConnectionTim assert(nPeersWithValidatedDownloads >= 0); g_outbound_peers_with_protect_from_disconnect -= state->m_chain_sync.m_protect; assert(g_outbound_peers_with_protect_from_disconnect >= 0); + g_wtxid_relay_peers -= state->m_wtxid_relay; + assert(g_wtxid_relay_peers >= 0); mapNodeState.erase(nodeid); @@ -866,6 +900,7 @@ void PeerLogicValidation::FinalizeNode(NodeId nodeid, bool& fUpdateConnectionTim assert(nPreferredDownload == 0); assert(nPeersWithValidatedDownloads == 0); assert(g_outbound_peers_with_protect_from_disconnect == 0); + assert(g_wtxid_relay_peers == 0); } LogPrint(BCLog::NET, "Cleared nodestate for peer=%d\n", nodeid); } @@ -924,6 +959,8 @@ bool AddOrphanTx(const CTransactionRef& tx, NodeId peer) EXCLUSIVE_LOCKS_REQUIRE auto ret = mapOrphanTransactions.emplace(hash, COrphanTx{tx, peer, GetTime() + ORPHAN_TX_EXPIRE_TIME, g_orphan_list.size()}); assert(ret.second); g_orphan_list.push_back(ret.first); + // Allow for lookups in the orphan pool by wtxid, as well as txid + g_orphans_by_wtxid.emplace(tx->GetWitnessHash(), ret.first); for (const CTxIn& txin : tx->vin) { mapOrphanTransactionsByPrev[txin.prevout].insert(ret.first); } @@ -960,6 +997,7 @@ int static EraseOrphanTx(uint256 hash) EXCLUSIVE_LOCKS_REQUIRED(g_cs_orphans) it_last->second.list_pos = old_pos; } g_orphan_list.pop_back(); + g_orphans_by_wtxid.erase(it->second.tx->GetWitnessHash()); mapOrphanTransactions.erase(it); return 1; @@ -1019,30 +1057,29 @@ unsigned int LimitOrphanTxSize(unsigned int nMaxOrphans) } /** - * Mark a misbehaving peer to be banned depending upon the value of `-banscore`. + * Increment peer's misbehavior score. If the new value >= DISCOURAGEMENT_THRESHOLD, mark the node + * to be discouraged, meaning the peer might be disconnected and added to the discouragement filter. */ -void Misbehaving(NodeId pnode, int howmuch, const std::string& message) EXCLUSIVE_LOCKS_REQUIRED(cs_main) +void Misbehaving(const NodeId pnode, const int howmuch, const std::string& message) EXCLUSIVE_LOCKS_REQUIRED(cs_main) { - if (howmuch == 0) - return; + assert(howmuch > 0); - CNodeState *state = State(pnode); - if (state == nullptr) - return; + CNodeState* const state = State(pnode); + if (state == nullptr) return; state->nMisbehavior += howmuch; - int banscore = gArgs.GetArg("-banscore", DEFAULT_BANSCORE_THRESHOLD); - std::string message_prefixed = message.empty() ? "" : (": " + message); - if (state->nMisbehavior >= banscore && state->nMisbehavior - howmuch < banscore) + const std::string message_prefixed = message.empty() ? "" : (": " + message); + if (state->nMisbehavior >= DISCOURAGEMENT_THRESHOLD && state->nMisbehavior - howmuch < DISCOURAGEMENT_THRESHOLD) { - LogPrint(BCLog::NET, "%s: %s peer=%d (%d -> %d) BAN THRESHOLD EXCEEDED%s\n", __func__, state->name, pnode, state->nMisbehavior-howmuch, state->nMisbehavior, message_prefixed); - state->fShouldBan = true; - } else - LogPrint(BCLog::NET, "%s: %s peer=%d (%d -> %d)%s\n", __func__, state->name, pnode, state->nMisbehavior-howmuch, state->nMisbehavior, message_prefixed); + LogPrint(BCLog::NET, "Misbehaving: peer=%d (%d -> %d) DISCOURAGE THRESHOLD EXCEEDED%s\n", pnode, state->nMisbehavior - howmuch, state->nMisbehavior, message_prefixed); + state->m_should_discourage = true; + } else { + LogPrint(BCLog::NET, "Misbehaving: peer=%d (%d -> %d)%s\n", pnode, state->nMisbehavior - howmuch, state->nMisbehavior, message_prefixed); + } } /** - * Potentially ban a node based on the contents of a BlockValidationState object + * Potentially mark a node discouraged based on the contents of a BlockValidationState object * * @param[in] via_compact_block this bool is passed in because net_processing should * punish peers differently depending on whether the data was provided in a compact @@ -1072,7 +1109,7 @@ static bool MaybePunishNodeForBlock(NodeId nodeid, const BlockValidationState& s break; } - // Ban outbound (but not inbound) peers if on an invalid chain. + // Discourage outbound (but not inbound) peers if on an invalid chain. // Exempt HB compact block peers and manual connections. if (!via_compact_block && !node_state->m_is_inbound && !node_state->m_is_manual_connection) { Misbehaving(nodeid, 100, message); @@ -1107,7 +1144,7 @@ static bool MaybePunishNodeForBlock(NodeId nodeid, const BlockValidationState& s } /** - * Potentially ban a node based on the contents of a TxValidationState object + * Potentially disconnect and discourage a node based on the contents of a TxValidationState object * * @return Returns true if the peer was punished (probably disconnected) */ @@ -1125,10 +1162,12 @@ static bool MaybePunishNodeForTx(NodeId nodeid, const TxValidationState& state, } // Conflicting (but not necessarily invalid) data or different policy: case TxValidationResult::TX_RECENT_CONSENSUS_CHANGE: + case TxValidationResult::TX_INPUTS_NOT_STANDARD: case TxValidationResult::TX_NOT_STANDARD: case TxValidationResult::TX_MISSING_INPUTS: case TxValidationResult::TX_PREMATURE_SPEND: case TxValidationResult::TX_WITNESS_MUTATED: + case TxValidationResult::TX_WITNESS_STRIPPED: case TxValidationResult::TX_CONFLICT: case TxValidationResult::TX_MEMPOOL_POLICY: break; @@ -1169,14 +1208,15 @@ PeerLogicValidation::PeerLogicValidation(CConnman* connmanIn, BanMan* banman, CS recentRejects.reset(new CRollingBloomFilter(120000, 0.000001)); // Blocks don't typically have more than 4000 transactions, so this should - // be at least six blocks (~1 hr) worth of transactions that we can store. + // be at least six blocks (~1 hr) worth of transactions that we can store, + // inserting both a txid and wtxid for every observed transaction. // If the number of transactions appearing in a block goes up, or if we are // seeing getdata requests more than an hour after initial announcement, we // can increase this number. // The false positive rate of 1/1M should come out to less than 1 // transaction per day that would be inadvertently ignored (which is the // same probability that we have in the reject filter). - g_recent_confirmed_transactions.reset(new CRollingBloomFilter(24000, 0.000001)); + g_recent_confirmed_transactions.reset(new CRollingBloomFilter(48000, 0.000001)); const Consensus::Params& consensusParams = Params().GetConsensus(); // Stale tip checking and peer eviction are on two different timers, but we @@ -1232,6 +1272,9 @@ void PeerLogicValidation::BlockConnected(const std::shared_ptr<const CBlock>& pb LOCK(g_cs_recent_confirmed_transactions); for (const auto& ptx : pblock->vtx) { g_recent_confirmed_transactions->insert(ptx->GetHash()); + if (ptx->GetHash() != ptx->GetWitnessHash()) { + g_recent_confirmed_transactions->insert(ptx->GetWitnessHash()); + } } } } @@ -1328,9 +1371,10 @@ void PeerLogicValidation::UpdatedBlockTip(const CBlockIndex *pindexNew, const CB } // Relay inventory, but don't relay old inventory during initial block download. connman->ForEachNode([nNewHeight, &vHashes](CNode* pnode) { + LOCK(pnode->cs_inventory); if (nNewHeight > (pnode->nStartingHeight != -1 ? pnode->nStartingHeight - 2000 : 0)) { for (const uint256& hash : reverse_iterate(vHashes)) { - pnode->PushBlockHash(hash); + pnode->vBlockHashesToAnnounce.push_back(hash); } } }); @@ -1339,7 +1383,7 @@ void PeerLogicValidation::UpdatedBlockTip(const CBlockIndex *pindexNew, const CB } /** - * Handle invalid block rejection and consequent peer banning, maintain which + * Handle invalid block rejection and consequent peer discouragement, maintain which * peers announce compact blocks. */ void PeerLogicValidation::BlockChecked(const CBlock& block, const BlockValidationState& state) { @@ -1365,7 +1409,7 @@ void PeerLogicValidation::BlockChecked(const CBlock& block, const BlockValidatio !::ChainstateActive().IsInitialBlockDownload() && mapBlocksInFlight.count(hash) == mapBlocksInFlight.size()) { if (it != mapBlockSource.end()) { - MaybeSetPeerAsAnnouncingHeaderAndIDs(it->second.first, connman); + MaybeSetPeerAsAnnouncingHeaderAndIDs(it->second.first, *connman); } } if (it != mapBlockSource.end()) @@ -1384,6 +1428,7 @@ bool static AlreadyHave(const CInv& inv, const CTxMemPool& mempool) EXCLUSIVE_LO { case MSG_TX: case MSG_WITNESS_TX: + case MSG_WTX: { assert(recentRejects); if (::ChainActive().Tip()->GetBlockHash() != hashRecentRejectsChainTip) @@ -1398,7 +1443,11 @@ bool static AlreadyHave(const CInv& inv, const CTxMemPool& mempool) EXCLUSIVE_LO { LOCK(g_cs_orphans); - if (mapOrphanTransactions.count(inv.hash)) return true; + if (!inv.IsMsgWtx() && mapOrphanTransactions.count(inv.hash)) { + return true; + } else if (inv.IsMsgWtx() && g_orphans_by_wtxid.count(inv.hash)) { + return true; + } } { @@ -1406,8 +1455,7 @@ bool static AlreadyHave(const CInv& inv, const CTxMemPool& mempool) EXCLUSIVE_LO if (g_recent_confirmed_transactions->contains(inv.hash)) return true; } - return recentRejects->contains(inv.hash) || - mempool.exists(inv.hash); + return recentRejects->contains(inv.hash) || mempool.exists(ToGenTxid(inv)); } case MSG_BLOCK: case MSG_WITNESS_BLOCK: @@ -1417,12 +1465,17 @@ bool static AlreadyHave(const CInv& inv, const CTxMemPool& mempool) EXCLUSIVE_LO return true; } -void RelayTransaction(const uint256& txid, const CConnman& connman) +void RelayTransaction(const uint256& txid, const uint256& wtxid, const CConnman& connman) { - CInv inv(MSG_TX, txid); - connman.ForEachNode([&inv](CNode* pnode) + connman.ForEachNode([&txid, &wtxid](CNode* pnode) { - pnode->PushInventory(inv); + AssertLockHeld(cs_main); + CNodeState &state = *State(pnode->GetId()); + if (state.m_wtxid_relay) { + pnode->PushTxInventory(wtxid); + } else { + pnode->PushTxInventory(txid); + } }); } @@ -1441,7 +1494,7 @@ static void RelayAddress(const CAddress& addr, bool fReachable, const CConnman& assert(nRelayNodes <= best.size()); auto sortfunc = [&best, &hasher, nRelayNodes](CNode* pnode) { - if (pnode->nVersion >= CADDR_TIME_VERSION && pnode->IsAddrRelayPeer()) { + if (pnode->IsAddrRelayPeer()) { uint64_t hashKey = CSipHasher(hasher).Write(pnode->GetId()).Finalize(); for (unsigned int i = 0; i < nRelayNodes; i++) { if (hashKey > best[i].first) { @@ -1462,7 +1515,7 @@ static void RelayAddress(const CAddress& addr, bool fReachable, const CConnman& connman.ForEachNodeThen(std::move(sortfunc), std::move(pushfunc)); } -void static ProcessGetBlockData(CNode* pfrom, const CChainParams& chainparams, const CInv& inv, CConnman* connman) +void static ProcessGetBlockData(CNode& pfrom, const CChainParams& chainparams, const CInv& inv, CConnman& connman) { bool send = false; std::shared_ptr<const CBlock> a_recent_block; @@ -1504,28 +1557,30 @@ void static ProcessGetBlockData(CNode* pfrom, const CChainParams& chainparams, c if (pindex) { send = BlockRequestAllowed(pindex, consensusParams); if (!send) { - LogPrint(BCLog::NET, "%s: ignoring request from peer=%i for old block that isn't in the main chain\n", __func__, pfrom->GetId()); + LogPrint(BCLog::NET, "%s: ignoring request from peer=%i for old block that isn't in the main chain\n", __func__, pfrom.GetId()); } } - const CNetMsgMaker msgMaker(pfrom->GetSendVersion()); + const CNetMsgMaker msgMaker(pfrom.GetSendVersion()); // disconnect node in case we have reached the outbound limit for serving historical blocks - // never disconnect whitelisted nodes - if (send && connman->OutboundTargetReached(true) && ( ((pindexBestHeader != nullptr) && (pindexBestHeader->GetBlockTime() - pindex->GetBlockTime() > HISTORICAL_BLOCK_AGE)) || inv.type == MSG_FILTERED_BLOCK) && !pfrom->HasPermission(PF_NOBAN)) - { - LogPrint(BCLog::NET, "historical block serving limit reached, disconnect peer=%d\n", pfrom->GetId()); + if (send && + connman.OutboundTargetReached(true) && + (((pindexBestHeader != nullptr) && (pindexBestHeader->GetBlockTime() - pindex->GetBlockTime() > HISTORICAL_BLOCK_AGE)) || inv.type == MSG_FILTERED_BLOCK) && + !pfrom.HasPermission(PF_DOWNLOAD) // nodes with the download permission may exceed target + ) { + LogPrint(BCLog::NET, "historical block serving limit reached, disconnect peer=%d\n", pfrom.GetId()); //disconnect node - pfrom->fDisconnect = true; + pfrom.fDisconnect = true; send = false; } // Avoid leaking prune-height by never sending blocks below the NODE_NETWORK_LIMITED threshold - if (send && !pfrom->HasPermission(PF_NOBAN) && ( - (((pfrom->GetLocalServices() & NODE_NETWORK_LIMITED) == NODE_NETWORK_LIMITED) && ((pfrom->GetLocalServices() & NODE_NETWORK) != NODE_NETWORK) && (::ChainActive().Tip()->nHeight - pindex->nHeight > (int)NODE_NETWORK_LIMITED_MIN_BLOCKS + 2 /* add two blocks buffer extension for possible races */) ) + if (send && !pfrom.HasPermission(PF_NOBAN) && ( + (((pfrom.GetLocalServices() & NODE_NETWORK_LIMITED) == NODE_NETWORK_LIMITED) && ((pfrom.GetLocalServices() & NODE_NETWORK) != NODE_NETWORK) && (::ChainActive().Tip()->nHeight - pindex->nHeight > (int)NODE_NETWORK_LIMITED_MIN_BLOCKS + 2 /* add two blocks buffer extension for possible races */) ) )) { - LogPrint(BCLog::NET, "Ignore block request below NODE_NETWORK_LIMITED threshold from peer=%d\n", pfrom->GetId()); + LogPrint(BCLog::NET, "Ignore block request below NODE_NETWORK_LIMITED threshold from peer=%d\n", pfrom.GetId()); //disconnect node and prevent it from stalling (would otherwise wait for the missing block) - pfrom->fDisconnect = true; + pfrom.fDisconnect = true; send = false; } // Pruned nodes may have deleted the block, so check whether @@ -1542,7 +1597,7 @@ void static ProcessGetBlockData(CNode* pfrom, const CChainParams& chainparams, c if (!ReadRawBlockFromDisk(block_data, pindex, chainparams.MessageStart())) { assert(!"cannot load block from disk"); } - connman->PushMessage(pfrom, msgMaker.Make(NetMsgType::BLOCK, MakeSpan(block_data))); + connman.PushMessage(&pfrom, msgMaker.Make(NetMsgType::BLOCK, MakeSpan(block_data))); // Don't set pblock as we've sent the block } else { // Send block from disk @@ -1553,22 +1608,22 @@ void static ProcessGetBlockData(CNode* pfrom, const CChainParams& chainparams, c } if (pblock) { if (inv.type == MSG_BLOCK) - connman->PushMessage(pfrom, msgMaker.Make(SERIALIZE_TRANSACTION_NO_WITNESS, NetMsgType::BLOCK, *pblock)); + connman.PushMessage(&pfrom, msgMaker.Make(SERIALIZE_TRANSACTION_NO_WITNESS, NetMsgType::BLOCK, *pblock)); else if (inv.type == MSG_WITNESS_BLOCK) - connman->PushMessage(pfrom, msgMaker.Make(NetMsgType::BLOCK, *pblock)); + connman.PushMessage(&pfrom, msgMaker.Make(NetMsgType::BLOCK, *pblock)); else if (inv.type == MSG_FILTERED_BLOCK) { bool sendMerkleBlock = false; CMerkleBlock merkleBlock; - if (pfrom->m_tx_relay != nullptr) { - LOCK(pfrom->m_tx_relay->cs_filter); - if (pfrom->m_tx_relay->pfilter) { + if (pfrom.m_tx_relay != nullptr) { + LOCK(pfrom.m_tx_relay->cs_filter); + if (pfrom.m_tx_relay->pfilter) { sendMerkleBlock = true; - merkleBlock = CMerkleBlock(*pblock, *pfrom->m_tx_relay->pfilter); + merkleBlock = CMerkleBlock(*pblock, *pfrom.m_tx_relay->pfilter); } } if (sendMerkleBlock) { - connman->PushMessage(pfrom, msgMaker.Make(NetMsgType::MERKLEBLOCK, merkleBlock)); + connman.PushMessage(&pfrom, msgMaker.Make(NetMsgType::MERKLEBLOCK, merkleBlock)); // CMerkleBlock just contains hashes, so also push any transactions in the block the client did not see // This avoids hurting performance by pointlessly requiring a round-trip // Note that there is currently no way for a node to request any single transactions we didn't send here - @@ -1577,7 +1632,7 @@ void static ProcessGetBlockData(CNode* pfrom, const CChainParams& chainparams, c // however we MUST always provide at least what the remote peer needs typedef std::pair<unsigned int, uint256> PairType; for (PairType& pair : merkleBlock.vMatchedTxn) - connman->PushMessage(pfrom, msgMaker.Make(SERIALIZE_TRANSACTION_NO_WITNESS, NetMsgType::TX, *pblock->vtx[pair.first])); + connman.PushMessage(&pfrom, msgMaker.Make(SERIALIZE_TRANSACTION_NO_WITNESS, NetMsgType::TX, *pblock->vtx[pair.first])); } // else // no response @@ -1588,101 +1643,120 @@ void static ProcessGetBlockData(CNode* pfrom, const CChainParams& chainparams, c // they won't have a useful mempool to match against a compact block, // and we don't feel like constructing the object for them, so // instead we respond with the full, non-compact block. - bool fPeerWantsWitness = State(pfrom->GetId())->fWantsCmpctWitness; + bool fPeerWantsWitness = State(pfrom.GetId())->fWantsCmpctWitness; int nSendFlags = fPeerWantsWitness ? 0 : SERIALIZE_TRANSACTION_NO_WITNESS; if (CanDirectFetch(consensusParams) && pindex->nHeight >= ::ChainActive().Height() - MAX_CMPCTBLOCK_DEPTH) { if ((fPeerWantsWitness || !fWitnessesPresentInARecentCompactBlock) && a_recent_compact_block && a_recent_compact_block->header.GetHash() == pindex->GetBlockHash()) { - connman->PushMessage(pfrom, msgMaker.Make(nSendFlags, NetMsgType::CMPCTBLOCK, *a_recent_compact_block)); + connman.PushMessage(&pfrom, msgMaker.Make(nSendFlags, NetMsgType::CMPCTBLOCK, *a_recent_compact_block)); } else { CBlockHeaderAndShortTxIDs cmpctblock(*pblock, fPeerWantsWitness); - connman->PushMessage(pfrom, msgMaker.Make(nSendFlags, NetMsgType::CMPCTBLOCK, cmpctblock)); + connman.PushMessage(&pfrom, msgMaker.Make(nSendFlags, NetMsgType::CMPCTBLOCK, cmpctblock)); } } else { - connman->PushMessage(pfrom, msgMaker.Make(nSendFlags, NetMsgType::BLOCK, *pblock)); + connman.PushMessage(&pfrom, msgMaker.Make(nSendFlags, NetMsgType::BLOCK, *pblock)); } } } // Trigger the peer node to send a getblocks request for the next batch of inventory - if (inv.hash == pfrom->hashContinue) + if (inv.hash == pfrom.hashContinue) { - // Bypass PushInventory, this must send even if redundant, + // Send immediately. This must send even if redundant, // and we want it right after the last block so they don't // wait for other stuff first. std::vector<CInv> vInv; vInv.push_back(CInv(MSG_BLOCK, ::ChainActive().Tip()->GetBlockHash())); - connman->PushMessage(pfrom, msgMaker.Make(NetMsgType::INV, vInv)); - pfrom->hashContinue.SetNull(); + connman.PushMessage(&pfrom, msgMaker.Make(NetMsgType::INV, vInv)); + pfrom.hashContinue.SetNull(); } } } //! Determine whether or not a peer can request a transaction, and return it (or nullptr if not found or not allowed). -CTransactionRef static FindTxForGetData(CNode* peer, const uint256& txid, const std::chrono::seconds mempool_req, const std::chrono::seconds longlived_mempool_time) LOCKS_EXCLUDED(cs_main) +CTransactionRef static FindTxForGetData(const CNode& peer, const GenTxid& gtxid, const std::chrono::seconds mempool_req, const std::chrono::seconds now) LOCKS_EXCLUDED(cs_main) { - // Check if the requested transaction is so recent that we're just - // about to announce it to the peer; if so, they certainly shouldn't - // know we already have it. - { - LOCK(peer->m_tx_relay->cs_tx_inventory); - if (peer->m_tx_relay->setInventoryTxToSend.count(txid)) return {}; + auto txinfo = mempool.info(gtxid); + if (txinfo.tx) { + // If a TX could have been INVed in reply to a MEMPOOL request, + // or is older than UNCONDITIONAL_RELAY_DELAY, permit the request + // unconditionally. + if ((mempool_req.count() && txinfo.m_time <= mempool_req) || txinfo.m_time <= now - UNCONDITIONAL_RELAY_DELAY) { + return std::move(txinfo.tx); + } } { LOCK(cs_main); - // Look up transaction in relay pool - auto mi = mapRelay.find(txid); - if (mi != mapRelay.end()) return mi->second; - } - - auto txinfo = mempool.info(txid); - if (txinfo.tx) { - // To protect privacy, do not answer getdata using the mempool when - // that TX couldn't have been INVed in reply to a MEMPOOL request, - // or when it's too recent to have expired from mapRelay. - if ((mempool_req.count() && txinfo.m_time <= mempool_req) || txinfo.m_time <= longlived_mempool_time) { - return txinfo.tx; + // Otherwise, the transaction must have been announced recently. + if (State(peer.GetId())->m_recently_announced_invs.contains(gtxid.GetHash())) { + // If it was, it can be relayed from either the mempool... + if (txinfo.tx) return std::move(txinfo.tx); + // ... or the relay pool. + auto mi = mapRelay.find(gtxid.GetHash()); + if (mi != mapRelay.end()) return mi->second; } } return {}; } -void static ProcessGetData(CNode* pfrom, const CChainParams& chainparams, CConnman* connman, CTxMemPool& mempool, const std::atomic<bool>& interruptMsgProc) LOCKS_EXCLUDED(cs_main) +void static ProcessGetData(CNode& pfrom, const CChainParams& chainparams, CConnman& connman, CTxMemPool& mempool, const std::atomic<bool>& interruptMsgProc) LOCKS_EXCLUDED(cs_main) { AssertLockNotHeld(cs_main); - std::deque<CInv>::iterator it = pfrom->vRecvGetData.begin(); + std::deque<CInv>::iterator it = pfrom.vRecvGetData.begin(); std::vector<CInv> vNotFound; - const CNetMsgMaker msgMaker(pfrom->GetSendVersion()); + const CNetMsgMaker msgMaker(pfrom.GetSendVersion()); - // mempool entries added before this time have likely expired from mapRelay - const std::chrono::seconds longlived_mempool_time = GetTime<std::chrono::seconds>() - RELAY_TX_CACHE_TIME; + const std::chrono::seconds now = GetTime<std::chrono::seconds>(); // Get last mempool request time - const std::chrono::seconds mempool_req = pfrom->m_tx_relay != nullptr ? pfrom->m_tx_relay->m_last_mempool_req.load() + const std::chrono::seconds mempool_req = pfrom.m_tx_relay != nullptr ? pfrom.m_tx_relay->m_last_mempool_req.load() : std::chrono::seconds::min(); // Process as many TX items from the front of the getdata queue as // possible, since they're common and it's efficient to batch process // them. - while (it != pfrom->vRecvGetData.end() && (it->type == MSG_TX || it->type == MSG_WITNESS_TX)) { + while (it != pfrom.vRecvGetData.end() && it->IsGenTxMsg()) { if (interruptMsgProc) return; // The send buffer provides backpressure. If there's no space in // the buffer, pause processing until the next call. - if (pfrom->fPauseSend) break; + if (pfrom.fPauseSend) break; const CInv &inv = *it++; - if (pfrom->m_tx_relay == nullptr) { + if (pfrom.m_tx_relay == nullptr) { // Ignore GETDATA requests for transactions from blocks-only peers. continue; } - CTransactionRef tx = FindTxForGetData(pfrom, inv.hash, mempool_req, longlived_mempool_time); + CTransactionRef tx = FindTxForGetData(pfrom, ToGenTxid(inv), mempool_req, now); if (tx) { - int nSendFlags = (inv.type == MSG_TX ? SERIALIZE_TRANSACTION_NO_WITNESS : 0); - connman->PushMessage(pfrom, msgMaker.Make(nSendFlags, NetMsgType::TX, *tx)); - mempool.RemoveUnbroadcastTx(inv.hash); + // WTX and WITNESS_TX imply we serialize with witness + int nSendFlags = (inv.IsMsgTx() ? SERIALIZE_TRANSACTION_NO_WITNESS : 0); + connman.PushMessage(&pfrom, msgMaker.Make(nSendFlags, NetMsgType::TX, *tx)); + mempool.RemoveUnbroadcastTx(tx->GetHash()); + // As we're going to send tx, make sure its unconfirmed parents are made requestable. + std::vector<uint256> parent_ids_to_add; + { + LOCK(mempool.cs); + auto txiter = mempool.GetIter(tx->GetHash()); + if (txiter) { + const CTxMemPool::setEntries& parents = mempool.GetMemPoolParents(*txiter); + parent_ids_to_add.reserve(parents.size()); + for (CTxMemPool::txiter parent_iter : parents) { + if (parent_iter->GetTime() > now - UNCONDITIONAL_RELAY_DELAY) { + parent_ids_to_add.push_back(parent_iter->GetTx().GetHash()); + } + } + } + } + for (const uint256& parent_txid : parent_ids_to_add) { + // Relaying a transaction with a recent but unconfirmed parent. + if (WITH_LOCK(pfrom.m_tx_relay->cs_tx_inventory, return !pfrom.m_tx_relay->filterInventoryKnown.contains(parent_txid))) { + LOCK(cs_main); + State(pfrom.GetId())->m_recently_announced_invs.insert(parent_txid); + } + } } else { vNotFound.push_back(inv); } @@ -1690,7 +1764,7 @@ void static ProcessGetData(CNode* pfrom, const CChainParams& chainparams, CConnm // Only process one BLOCK item per call, since they're uncommon and can be // expensive to process. - if (it != pfrom->vRecvGetData.end() && !pfrom->fPauseSend) { + if (it != pfrom.vRecvGetData.end() && !pfrom.fPauseSend) { const CInv &inv = *it++; if (inv.type == MSG_BLOCK || inv.type == MSG_FILTERED_BLOCK || inv.type == MSG_CMPCT_BLOCK || inv.type == MSG_WITNESS_BLOCK) { ProcessGetBlockData(pfrom, chainparams, inv, connman); @@ -1699,7 +1773,7 @@ void static ProcessGetData(CNode* pfrom, const CChainParams& chainparams, CConnm // and continue processing the queue on the next call. } - pfrom->vRecvGetData.erase(pfrom->vRecvGetData.begin(), it); + pfrom.vRecvGetData.erase(pfrom.vRecvGetData.begin(), it); if (!vNotFound.empty()) { // Let the peer know that we didn't find what it asked for, so it doesn't @@ -1716,49 +1790,49 @@ void static ProcessGetData(CNode* pfrom, const CChainParams& chainparams, CConnm // In normal operation, we often send NOTFOUND messages for parents of // transactions that we relay; if a peer is missing a parent, they may // assume we have them and request the parents from us. - connman->PushMessage(pfrom, msgMaker.Make(NetMsgType::NOTFOUND, vNotFound)); + connman.PushMessage(&pfrom, msgMaker.Make(NetMsgType::NOTFOUND, vNotFound)); } } -static uint32_t GetFetchFlags(CNode* pfrom) EXCLUSIVE_LOCKS_REQUIRED(cs_main) { +static uint32_t GetFetchFlags(const CNode& pfrom) EXCLUSIVE_LOCKS_REQUIRED(cs_main) { uint32_t nFetchFlags = 0; - if ((pfrom->GetLocalServices() & NODE_WITNESS) && State(pfrom->GetId())->fHaveWitness) { + if ((pfrom.GetLocalServices() & NODE_WITNESS) && State(pfrom.GetId())->fHaveWitness) { nFetchFlags |= MSG_WITNESS_FLAG; } return nFetchFlags; } -inline void static SendBlockTransactions(const CBlock& block, const BlockTransactionsRequest& req, CNode* pfrom, CConnman* connman) { +inline void static SendBlockTransactions(const CBlock& block, const BlockTransactionsRequest& req, CNode& pfrom, CConnman& connman) { BlockTransactions resp(req); for (size_t i = 0; i < req.indexes.size(); i++) { if (req.indexes[i] >= block.vtx.size()) { LOCK(cs_main); - Misbehaving(pfrom->GetId(), 100, strprintf("Peer %d sent us a getblocktxn with out-of-bounds tx indices", pfrom->GetId())); + Misbehaving(pfrom.GetId(), 100, "getblocktxn with out-of-bounds tx indices"); return; } resp.txn[i] = block.vtx[req.indexes[i]]; } LOCK(cs_main); - const CNetMsgMaker msgMaker(pfrom->GetSendVersion()); - int nSendFlags = State(pfrom->GetId())->fWantsCmpctWitness ? 0 : SERIALIZE_TRANSACTION_NO_WITNESS; - connman->PushMessage(pfrom, msgMaker.Make(nSendFlags, NetMsgType::BLOCKTXN, resp)); + const CNetMsgMaker msgMaker(pfrom.GetSendVersion()); + int nSendFlags = State(pfrom.GetId())->fWantsCmpctWitness ? 0 : SERIALIZE_TRANSACTION_NO_WITNESS; + connman.PushMessage(&pfrom, msgMaker.Make(nSendFlags, NetMsgType::BLOCKTXN, resp)); } -bool static ProcessHeadersMessage(CNode* pfrom, CConnman* connman, ChainstateManager& chainman, CTxMemPool& mempool, const std::vector<CBlockHeader>& headers, const CChainParams& chainparams, bool via_compact_block) +static void ProcessHeadersMessage(CNode& pfrom, CConnman& connman, ChainstateManager& chainman, CTxMemPool& mempool, const std::vector<CBlockHeader>& headers, const CChainParams& chainparams, bool via_compact_block) { - const CNetMsgMaker msgMaker(pfrom->GetSendVersion()); + const CNetMsgMaker msgMaker(pfrom.GetSendVersion()); size_t nCount = headers.size(); if (nCount == 0) { // Nothing interesting. Stop asking this peers for more headers. - return true; + return; } bool received_new_header = false; const CBlockIndex *pindexLast = nullptr; { LOCK(cs_main); - CNodeState *nodestate = State(pfrom->GetId()); + CNodeState *nodestate = State(pfrom.GetId()); // If this looks like it could be a block announcement (nCount < // MAX_BLOCKS_TO_ANNOUNCE), use special logic for handling headers that @@ -1770,28 +1844,28 @@ bool static ProcessHeadersMessage(CNode* pfrom, CConnman* connman, ChainstateMan // nUnconnectingHeaders gets reset back to 0. if (!LookupBlockIndex(headers[0].hashPrevBlock) && nCount < MAX_BLOCKS_TO_ANNOUNCE) { nodestate->nUnconnectingHeaders++; - connman->PushMessage(pfrom, msgMaker.Make(NetMsgType::GETHEADERS, ::ChainActive().GetLocator(pindexBestHeader), uint256())); + connman.PushMessage(&pfrom, msgMaker.Make(NetMsgType::GETHEADERS, ::ChainActive().GetLocator(pindexBestHeader), uint256())); LogPrint(BCLog::NET, "received header %s: missing prev block %s, sending getheaders (%d) to end (peer=%d, nUnconnectingHeaders=%d)\n", headers[0].GetHash().ToString(), headers[0].hashPrevBlock.ToString(), pindexBestHeader->nHeight, - pfrom->GetId(), nodestate->nUnconnectingHeaders); + pfrom.GetId(), nodestate->nUnconnectingHeaders); // Set hashLastUnknownBlock for this peer, so that if we // eventually get the headers - even from a different peer - // we can use this peer to download. - UpdateBlockAvailability(pfrom->GetId(), headers.back().GetHash()); + UpdateBlockAvailability(pfrom.GetId(), headers.back().GetHash()); if (nodestate->nUnconnectingHeaders % MAX_UNCONNECTING_HEADERS == 0) { - Misbehaving(pfrom->GetId(), 20); + Misbehaving(pfrom.GetId(), 20, strprintf("%d non-connecting headers", nodestate->nUnconnectingHeaders)); } - return true; + return; } uint256 hashLastBlock; for (const CBlockHeader& header : headers) { if (!hashLastBlock.IsNull() && header.hashPrevBlock != hashLastBlock) { - Misbehaving(pfrom->GetId(), 20, "non-continuous headers sequence"); - return false; + Misbehaving(pfrom.GetId(), 20, "non-continuous headers sequence"); + return; } hashLastBlock = header.GetHash(); } @@ -1806,21 +1880,21 @@ bool static ProcessHeadersMessage(CNode* pfrom, CConnman* connman, ChainstateMan BlockValidationState state; if (!chainman.ProcessNewBlockHeaders(headers, state, chainparams, &pindexLast)) { if (state.IsInvalid()) { - MaybePunishNodeForBlock(pfrom->GetId(), state, via_compact_block, "invalid header received"); - return false; + MaybePunishNodeForBlock(pfrom.GetId(), state, via_compact_block, "invalid header received"); + return; } } { LOCK(cs_main); - CNodeState *nodestate = State(pfrom->GetId()); + CNodeState *nodestate = State(pfrom.GetId()); if (nodestate->nUnconnectingHeaders > 0) { - LogPrint(BCLog::NET, "peer=%d: resetting nUnconnectingHeaders (%d -> 0)\n", pfrom->GetId(), nodestate->nUnconnectingHeaders); + LogPrint(BCLog::NET, "peer=%d: resetting nUnconnectingHeaders (%d -> 0)\n", pfrom.GetId(), nodestate->nUnconnectingHeaders); } nodestate->nUnconnectingHeaders = 0; assert(pindexLast); - UpdateBlockAvailability(pfrom->GetId(), pindexLast->GetBlockHash()); + UpdateBlockAvailability(pfrom.GetId(), pindexLast->GetBlockHash()); // From here, pindexBestKnownBlock should be guaranteed to be non-null, // because it is set in UpdateBlockAvailability. Some nullptr checks @@ -1834,8 +1908,8 @@ bool static ProcessHeadersMessage(CNode* pfrom, CConnman* connman, ChainstateMan // Headers message had its maximum size; the peer may have more headers. // TODO: optimize: if pindexLast is an ancestor of ::ChainActive().Tip or pindexBestHeader, continue // from there instead. - LogPrint(BCLog::NET, "more getheaders (%d) to end to peer=%d (startheight:%d)\n", pindexLast->nHeight, pfrom->GetId(), pfrom->nStartingHeight); - connman->PushMessage(pfrom, msgMaker.Make(NetMsgType::GETHEADERS, ::ChainActive().GetLocator(pindexLast), uint256())); + LogPrint(BCLog::NET, "more getheaders (%d) to end to peer=%d (startheight:%d)\n", pindexLast->nHeight, pfrom.GetId(), pfrom.nStartingHeight); + connman.PushMessage(&pfrom, msgMaker.Make(NetMsgType::GETHEADERS, ::ChainActive().GetLocator(pindexLast), uint256())); } bool fCanDirectFetch = CanDirectFetch(chainparams.GetConsensus()); @@ -1848,7 +1922,7 @@ bool static ProcessHeadersMessage(CNode* pfrom, CConnman* connman, ChainstateMan while (pindexWalk && !::ChainActive().Contains(pindexWalk) && vToFetch.size() <= MAX_BLOCKS_IN_TRANSIT_PER_PEER) { if (!(pindexWalk->nStatus & BLOCK_HAVE_DATA) && !mapBlocksInFlight.count(pindexWalk->GetBlockHash()) && - (!IsWitnessEnabled(pindexWalk->pprev, chainparams.GetConsensus()) || State(pfrom->GetId())->fHaveWitness)) { + (!IsWitnessEnabled(pindexWalk->pprev, chainparams.GetConsensus()) || State(pfrom.GetId())->fHaveWitness)) { // We don't have this block, and it's not yet in flight. vToFetch.push_back(pindexWalk); } @@ -1872,9 +1946,9 @@ bool static ProcessHeadersMessage(CNode* pfrom, CConnman* connman, ChainstateMan } uint32_t nFetchFlags = GetFetchFlags(pfrom); vGetData.push_back(CInv(MSG_BLOCK | nFetchFlags, pindex->GetBlockHash())); - MarkBlockAsInFlight(mempool, pfrom->GetId(), pindex->GetBlockHash(), pindex); + MarkBlockAsInFlight(mempool, pfrom.GetId(), pindex->GetBlockHash(), pindex); LogPrint(BCLog::NET, "Requesting block %s from peer=%d\n", - pindex->GetBlockHash().ToString(), pfrom->GetId()); + pindex->GetBlockHash().ToString(), pfrom.GetId()); } if (vGetData.size() > 1) { LogPrint(BCLog::NET, "Downloading blocks toward %s (%d) via headers direct fetch\n", @@ -1885,7 +1959,7 @@ bool static ProcessHeadersMessage(CNode* pfrom, CConnman* connman, ChainstateMan // In any case, we want to download using a compact block, not a regular one vGetData[0] = CInv(MSG_CMPCT_BLOCK, vGetData[0].hash); } - connman->PushMessage(pfrom, msgMaker.Make(NetMsgType::GETDATA, vGetData)); + connman.PushMessage(&pfrom, msgMaker.Make(NetMsgType::GETDATA, vGetData)); } } } @@ -1896,37 +1970,37 @@ bool static ProcessHeadersMessage(CNode* pfrom, CConnman* connman, ChainstateMan // headers to fetch from this peer. if (nodestate->pindexBestKnownBlock && nodestate->pindexBestKnownBlock->nChainWork < nMinimumChainWork) { // This peer has too little work on their headers chain to help - // us sync -- disconnect if using an outbound slot (unless - // whitelisted or addnode). + // us sync -- disconnect if it is an outbound disconnection + // candidate. // Note: We compare their tip to nMinimumChainWork (rather than // ::ChainActive().Tip()) because we won't start block download // until we have a headers chain that has at least // nMinimumChainWork, even if a peer has a chain past our tip, // as an anti-DoS measure. - if (IsOutboundDisconnectionCandidate(pfrom)) { - LogPrintf("Disconnecting outbound peer %d -- headers chain has insufficient work\n", pfrom->GetId()); - pfrom->fDisconnect = true; + if (pfrom.IsOutboundOrBlockRelayConn()) { + LogPrintf("Disconnecting outbound peer %d -- headers chain has insufficient work\n", pfrom.GetId()); + pfrom.fDisconnect = true; } } } - if (!pfrom->fDisconnect && IsOutboundDisconnectionCandidate(pfrom) && nodestate->pindexBestKnownBlock != nullptr && pfrom->m_tx_relay != nullptr) { + if (!pfrom.fDisconnect && pfrom.IsOutboundOrBlockRelayConn() && nodestate->pindexBestKnownBlock != nullptr && pfrom.m_tx_relay != nullptr) { // If this is an outbound full-relay peer, check to see if we should protect // it from the bad/lagging chain logic. // Note that block-relay-only peers are already implicitly protected, so we // only consider setting m_protect for the full-relay peers. if (g_outbound_peers_with_protect_from_disconnect < MAX_OUTBOUND_PEERS_TO_PROTECT_FROM_DISCONNECT && nodestate->pindexBestKnownBlock->nChainWork >= ::ChainActive().Tip()->nChainWork && !nodestate->m_chain_sync.m_protect) { - LogPrint(BCLog::NET, "Protecting outbound peer=%d from eviction\n", pfrom->GetId()); + LogPrint(BCLog::NET, "Protecting outbound peer=%d from eviction\n", pfrom.GetId()); nodestate->m_chain_sync.m_protect = true; ++g_outbound_peers_with_protect_from_disconnect; } } } - return true; + return; } -void static ProcessOrphanTx(CConnman* connman, CTxMemPool& mempool, std::set<uint256>& orphan_work_set, std::list<CTransactionRef>& removed_txn) EXCLUSIVE_LOCKS_REQUIRED(cs_main, g_cs_orphans) +void static ProcessOrphanTx(CConnman& connman, CTxMemPool& mempool, std::set<uint256>& orphan_work_set, std::list<CTransactionRef>& removed_txn) EXCLUSIVE_LOCKS_REQUIRED(cs_main, g_cs_orphans) { AssertLockHeld(cs_main); AssertLockHeld(g_cs_orphans); @@ -1950,7 +2024,7 @@ void static ProcessOrphanTx(CConnman* connman, CTxMemPool& mempool, std::set<uin if (setMisbehaving.count(fromPeer)) continue; if (AcceptToMemoryPool(mempool, orphan_state, porphanTx, &removed_txn, false /* bypass_limits */, 0 /* nAbsurdFee */)) { LogPrint(BCLog::MEMPOOL, " accepted orphan tx %s\n", orphanHash.ToString()); - RelayTransaction(orphanHash, *connman); + RelayTransaction(orphanHash, porphanTx->GetWitnessHash(), connman); for (unsigned int i = 0; i < orphanTx.vout.size(); i++) { auto it_by_prev = mapOrphanTransactionsByPrev.find(COutPoint(orphanHash, i)); if (it_by_prev != mapOrphanTransactionsByPrev.end()) { @@ -1967,17 +2041,43 @@ void static ProcessOrphanTx(CConnman* connman, CTxMemPool& mempool, std::set<uin if (MaybePunishNodeForTx(fromPeer, orphan_state)) { setMisbehaving.insert(fromPeer); } - LogPrint(BCLog::MEMPOOL, " invalid orphan tx %s\n", orphanHash.ToString()); + LogPrint(BCLog::MEMPOOL, " invalid orphan tx %s from peer=%d. %s\n", + orphanHash.ToString(), + fromPeer, + orphan_state.ToString()); } // Has inputs but not accepted to mempool // Probably non-standard or insufficient fee LogPrint(BCLog::MEMPOOL, " removed orphan tx %s\n", orphanHash.ToString()); - if (!orphanTx.HasWitness() && orphan_state.GetResult() != TxValidationResult::TX_WITNESS_MUTATED) { - // Do not use rejection cache for witness transactions or - // witness-stripped transactions, as they can have been malleated. - // See https://github.com/bitcoin/bitcoin/issues/8279 for details. + if (orphan_state.GetResult() != TxValidationResult::TX_WITNESS_STRIPPED) { + // We can add the wtxid of this transaction to our reject filter. + // Do not add txids of witness transactions or witness-stripped + // transactions to the filter, as they can have been malleated; + // adding such txids to the reject filter would potentially + // interfere with relay of valid transactions from peers that + // do not support wtxid-based relay. See + // https://github.com/bitcoin/bitcoin/issues/8279 for details. + // We can remove this restriction (and always add wtxids to + // the filter even for witness stripped transactions) once + // wtxid-based relay is broadly deployed. + // See also comments in https://github.com/bitcoin/bitcoin/pull/18044#discussion_r443419034 + // for concerns around weakening security of unupgraded nodes + // if we start doing this too early. assert(recentRejects); - recentRejects->insert(orphanHash); + recentRejects->insert(orphanTx.GetWitnessHash()); + // If the transaction failed for TX_INPUTS_NOT_STANDARD, + // then we know that the witness was irrelevant to the policy + // failure, since this check depends only on the txid + // (the scriptPubKey being spent is covered by the txid). + // Add the txid to the reject filter to prevent repeated + // processing of this transaction in the event that child + // transactions are later received (resulting in + // parent-fetching by txid via the orphan-handling logic). + if (orphan_state.GetResult() == TxValidationResult::TX_INPUTS_NOT_STANDARD && orphanTx.GetWitnessHash() != orphanTx.GetHash()) { + // We only add the txid if it differs from the wtxid, to + // avoid wasting entries in the rolling bloom filter. + recentRejects->insert(orphanTx.GetHash()); + } } EraseOrphanTx(orphanHash); done = true; @@ -2202,37 +2302,33 @@ static void ProcessGetCFCheckPt(CNode& peer, CDataStream& vRecv, const CChainPar connman.PushMessage(&peer, std::move(msg)); } -bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRecv, int64_t nTimeReceived, const CChainParams& chainparams, ChainstateManager& chainman, CTxMemPool& mempool, CConnman* connman, BanMan* banman, const std::atomic<bool>& interruptMsgProc) +void ProcessMessage( + CNode& pfrom, + const std::string& msg_type, + CDataStream& vRecv, + const std::chrono::microseconds time_received, + const CChainParams& chainparams, + ChainstateManager& chainman, + CTxMemPool& mempool, + CConnman& connman, + BanMan* banman, + const std::atomic<bool>& interruptMsgProc) { - LogPrint(BCLog::NET, "received: %s (%u bytes) peer=%d\n", SanitizeString(msg_type), vRecv.size(), pfrom->GetId()); + LogPrint(BCLog::NET, "received: %s (%u bytes) peer=%d\n", SanitizeString(msg_type), vRecv.size(), pfrom.GetId()); if (gArgs.IsArgSet("-dropmessagestest") && GetRand(gArgs.GetArg("-dropmessagestest", 0)) == 0) { LogPrintf("dropmessagestest DROPPING RECV MESSAGE\n"); - return true; + return; } - if (!(pfrom->GetLocalServices() & NODE_BLOOM) && - (msg_type == NetMsgType::FILTERLOAD || - msg_type == NetMsgType::FILTERADD)) - { - if (pfrom->nVersion >= NO_BLOOM_VERSION) { - LOCK(cs_main); - Misbehaving(pfrom->GetId(), 100); - return false; - } else { - pfrom->fDisconnect = true; - return false; - } - } - if (msg_type == NetMsgType::VERSION) { // Each connection can only send one version message - if (pfrom->nVersion != 0) + if (pfrom.nVersion != 0) { LOCK(cs_main); - Misbehaving(pfrom->GetId(), 1); - return false; + Misbehaving(pfrom.GetId(), 1, "redundant version message"); + return; } int64_t nTime; @@ -2250,22 +2346,22 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec vRecv >> nVersion >> nServiceInt >> nTime >> addrMe; nSendVersion = std::min(nVersion, PROTOCOL_VERSION); nServices = ServiceFlags(nServiceInt); - if (!pfrom->fInbound) + if (!pfrom.IsInboundConn()) { - connman->SetServices(pfrom->addr, nServices); + connman.SetServices(pfrom.addr, nServices); } - if (!pfrom->fInbound && !pfrom->fFeeler && !pfrom->m_manual_connection && !HasAllDesirableServiceFlags(nServices)) + if (pfrom.ExpectServicesFromConn() && !HasAllDesirableServiceFlags(nServices)) { - LogPrint(BCLog::NET, "peer=%d does not offer the expected services (%08x offered, %08x expected); disconnecting\n", pfrom->GetId(), nServices, GetDesirableServiceFlags(nServices)); - pfrom->fDisconnect = true; - return false; + LogPrint(BCLog::NET, "peer=%d does not offer the expected services (%08x offered, %08x expected); disconnecting\n", pfrom.GetId(), nServices, GetDesirableServiceFlags(nServices)); + pfrom.fDisconnect = true; + return; } if (nVersion < MIN_PEER_PROTO_VERSION) { // disconnect from peers older than this proto version - LogPrint(BCLog::NET, "peer=%d using obsolete version %i; disconnecting\n", pfrom->GetId(), nVersion); - pfrom->fDisconnect = true; - return false; + LogPrint(BCLog::NET, "peer=%d using obsolete version %i; disconnecting\n", pfrom.GetId(), nVersion); + pfrom.fDisconnect = true; + return; } if (!vRecv.empty()) @@ -2281,145 +2377,145 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec if (!vRecv.empty()) vRecv >> fRelay; // Disconnect if we connected to ourself - if (pfrom->fInbound && !connman->CheckIncomingNonce(nNonce)) + if (pfrom.IsInboundConn() && !connman.CheckIncomingNonce(nNonce)) { - LogPrintf("connected to self at %s, disconnecting\n", pfrom->addr.ToString()); - pfrom->fDisconnect = true; - return true; + LogPrintf("connected to self at %s, disconnecting\n", pfrom.addr.ToString()); + pfrom.fDisconnect = true; + return; } - if (pfrom->fInbound && addrMe.IsRoutable()) + if (pfrom.IsInboundConn() && addrMe.IsRoutable()) { SeenLocal(addrMe); } // Be shy and don't send version until we hear - if (pfrom->fInbound) + if (pfrom.IsInboundConn()) PushNodeVersion(pfrom, connman, GetAdjustedTime()); - connman->PushMessage(pfrom, CNetMsgMaker(INIT_PROTO_VERSION).Make(NetMsgType::VERACK)); + if (nVersion >= WTXID_RELAY_VERSION) { + connman.PushMessage(&pfrom, CNetMsgMaker(INIT_PROTO_VERSION).Make(NetMsgType::WTXIDRELAY)); + } + + connman.PushMessage(&pfrom, CNetMsgMaker(INIT_PROTO_VERSION).Make(NetMsgType::VERACK)); - pfrom->nServices = nServices; - pfrom->SetAddrLocal(addrMe); + pfrom.nServices = nServices; + pfrom.SetAddrLocal(addrMe); { - LOCK(pfrom->cs_SubVer); - pfrom->cleanSubVer = cleanSubVer; + LOCK(pfrom.cs_SubVer); + pfrom.cleanSubVer = cleanSubVer; } - pfrom->nStartingHeight = nStartingHeight; + pfrom.nStartingHeight = nStartingHeight; // set nodes not relaying blocks and tx and not serving (parts) of the historical blockchain as "clients" - pfrom->fClient = (!(nServices & NODE_NETWORK) && !(nServices & NODE_NETWORK_LIMITED)); + pfrom.fClient = (!(nServices & NODE_NETWORK) && !(nServices & NODE_NETWORK_LIMITED)); // set nodes not capable of serving the complete blockchain history as "limited nodes" - pfrom->m_limited_node = (!(nServices & NODE_NETWORK) && (nServices & NODE_NETWORK_LIMITED)); + pfrom.m_limited_node = (!(nServices & NODE_NETWORK) && (nServices & NODE_NETWORK_LIMITED)); - if (pfrom->m_tx_relay != nullptr) { - LOCK(pfrom->m_tx_relay->cs_filter); - pfrom->m_tx_relay->fRelayTxes = fRelay; // set to true after we get the first filter* message + if (pfrom.m_tx_relay != nullptr) { + LOCK(pfrom.m_tx_relay->cs_filter); + pfrom.m_tx_relay->fRelayTxes = fRelay; // set to true after we get the first filter* message } // Change version - pfrom->SetSendVersion(nSendVersion); - pfrom->nVersion = nVersion; + pfrom.SetSendVersion(nSendVersion); + pfrom.nVersion = nVersion; if((nServices & NODE_WITNESS)) { LOCK(cs_main); - State(pfrom->GetId())->fHaveWitness = true; + State(pfrom.GetId())->fHaveWitness = true; } // Potentially mark this peer as a preferred download peer. { LOCK(cs_main); - UpdatePreferredDownload(pfrom, State(pfrom->GetId())); + UpdatePreferredDownload(pfrom, State(pfrom.GetId())); } - if (!pfrom->fInbound && pfrom->IsAddrRelayPeer()) + if (!pfrom.IsInboundConn() && pfrom.IsAddrRelayPeer()) { // Advertise our address if (fListen && !::ChainstateActive().IsInitialBlockDownload()) { - CAddress addr = GetLocalAddress(&pfrom->addr, pfrom->GetLocalServices()); + CAddress addr = GetLocalAddress(&pfrom.addr, pfrom.GetLocalServices()); FastRandomContext insecure_rand; if (addr.IsRoutable()) { LogPrint(BCLog::NET, "ProcessMessages: advertising address %s\n", addr.ToString()); - pfrom->PushAddress(addr, insecure_rand); - } else if (IsPeerAddrLocalGood(pfrom)) { + pfrom.PushAddress(addr, insecure_rand); + } else if (IsPeerAddrLocalGood(&pfrom)) { addr.SetIP(addrMe); LogPrint(BCLog::NET, "ProcessMessages: advertising address %s\n", addr.ToString()); - pfrom->PushAddress(addr, insecure_rand); + pfrom.PushAddress(addr, insecure_rand); } } // Get recent addresses - if (pfrom->fOneShot || pfrom->nVersion >= CADDR_TIME_VERSION || connman->GetAddressCount() < 1000) - { - connman->PushMessage(pfrom, CNetMsgMaker(nSendVersion).Make(NetMsgType::GETADDR)); - pfrom->fGetAddr = true; - } - connman->MarkAddressGood(pfrom->addr); + connman.PushMessage(&pfrom, CNetMsgMaker(nSendVersion).Make(NetMsgType::GETADDR)); + pfrom.fGetAddr = true; + connman.MarkAddressGood(pfrom.addr); } std::string remoteAddr; if (fLogIPs) - remoteAddr = ", peeraddr=" + pfrom->addr.ToString(); + remoteAddr = ", peeraddr=" + pfrom.addr.ToString(); LogPrint(BCLog::NET, "receive version message: %s: version %d, blocks=%d, us=%s, peer=%d%s\n", - cleanSubVer, pfrom->nVersion, - pfrom->nStartingHeight, addrMe.ToString(), pfrom->GetId(), + cleanSubVer, pfrom.nVersion, + pfrom.nStartingHeight, addrMe.ToString(), pfrom.GetId(), remoteAddr); int64_t nTimeOffset = nTime - GetTime(); - pfrom->nTimeOffset = nTimeOffset; - AddTimeData(pfrom->addr, nTimeOffset); + pfrom.nTimeOffset = nTimeOffset; + AddTimeData(pfrom.addr, nTimeOffset); // If the peer is old enough to have the old alert system, send it the final alert. - if (pfrom->nVersion <= 70012) { + if (pfrom.nVersion <= 70012) { CDataStream finalAlert(ParseHex("60010000000000000000000000ffffff7f00000000ffffff7ffeffff7f01ffffff7f00000000ffffff7f00ffffff7f002f555247454e543a20416c657274206b657920636f6d70726f6d697365642c2075706772616465207265717569726564004630440220653febd6410f470f6bae11cad19c48413becb1ac2c17f908fd0fd53bdc3abd5202206d0e9c96fe88d4a0f01ed9dedae2b6f9e00da94cad0fecaae66ecf689bf71b50"), SER_NETWORK, PROTOCOL_VERSION); - connman->PushMessage(pfrom, CNetMsgMaker(nSendVersion).Make("alert", finalAlert)); + connman.PushMessage(&pfrom, CNetMsgMaker(nSendVersion).Make("alert", finalAlert)); } // Feeler connections exist only to verify if address is online. - if (pfrom->fFeeler) { - assert(pfrom->fInbound == false); - pfrom->fDisconnect = true; + if (pfrom.IsFeelerConn()) { + pfrom.fDisconnect = true; } - return true; + return; } - if (pfrom->nVersion == 0) { + if (pfrom.nVersion == 0) { // Must have a version message before anything else LOCK(cs_main); - Misbehaving(pfrom->GetId(), 1); - return false; + Misbehaving(pfrom.GetId(), 1, "non-version message before version handshake"); + return; } // At this point, the outgoing message serialization version can't change. - const CNetMsgMaker msgMaker(pfrom->GetSendVersion()); + const CNetMsgMaker msgMaker(pfrom.GetSendVersion()); if (msg_type == NetMsgType::VERACK) { - pfrom->SetRecvVersion(std::min(pfrom->nVersion.load(), PROTOCOL_VERSION)); + pfrom.SetRecvVersion(std::min(pfrom.nVersion.load(), PROTOCOL_VERSION)); - if (!pfrom->fInbound) { + if (!pfrom.IsInboundConn()) { // Mark this node as currently connected, so we update its timestamp later. LOCK(cs_main); - State(pfrom->GetId())->fCurrentlyConnected = true; + State(pfrom.GetId())->fCurrentlyConnected = true; LogPrintf("New outbound peer connected: version: %d, blocks=%d, peer=%d%s (%s)\n", - pfrom->nVersion.load(), pfrom->nStartingHeight, - pfrom->GetId(), (fLogIPs ? strprintf(", peeraddr=%s", pfrom->addr.ToString()) : ""), - pfrom->m_tx_relay == nullptr ? "block-relay" : "full-relay"); + pfrom.nVersion.load(), pfrom.nStartingHeight, + pfrom.GetId(), (fLogIPs ? strprintf(", peeraddr=%s", pfrom.addr.ToString()) : ""), + pfrom.m_tx_relay == nullptr ? "block-relay" : "full-relay"); } - if (pfrom->nVersion >= SENDHEADERS_VERSION) { + if (pfrom.nVersion >= SENDHEADERS_VERSION) { // Tell our peer we prefer to receive headers rather than inv's // We send this to non-NODE NETWORK peers as well, because even // non-NODE NETWORK peers can announce blocks (such as pruning // nodes) - connman->PushMessage(pfrom, msgMaker.Make(NetMsgType::SENDHEADERS)); + connman.PushMessage(&pfrom, msgMaker.Make(NetMsgType::SENDHEADERS)); } - if (pfrom->nVersion >= SHORT_IDS_BLOCKS_VERSION) { + if (pfrom.nVersion >= SHORT_IDS_BLOCKS_VERSION) { // Tell our peer we are willing to provide version 1 or 2 cmpctblocks // However, we do not request new block announcements using // cmpctblock messages. @@ -2427,37 +2523,53 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec // they may wish to request compact blocks from us bool fAnnounceUsingCMPCTBLOCK = false; uint64_t nCMPCTBLOCKVersion = 2; - if (pfrom->GetLocalServices() & NODE_WITNESS) - connman->PushMessage(pfrom, msgMaker.Make(NetMsgType::SENDCMPCT, fAnnounceUsingCMPCTBLOCK, nCMPCTBLOCKVersion)); + if (pfrom.GetLocalServices() & NODE_WITNESS) + connman.PushMessage(&pfrom, msgMaker.Make(NetMsgType::SENDCMPCT, fAnnounceUsingCMPCTBLOCK, nCMPCTBLOCKVersion)); nCMPCTBLOCKVersion = 1; - connman->PushMessage(pfrom, msgMaker.Make(NetMsgType::SENDCMPCT, fAnnounceUsingCMPCTBLOCK, nCMPCTBLOCKVersion)); + connman.PushMessage(&pfrom, msgMaker.Make(NetMsgType::SENDCMPCT, fAnnounceUsingCMPCTBLOCK, nCMPCTBLOCKVersion)); } - pfrom->fSuccessfullyConnected = true; - return true; + pfrom.fSuccessfullyConnected = true; + return; + } + + // Feature negotiation of wtxidrelay should happen between VERSION and + // VERACK, to avoid relay problems from switching after a connection is up + if (msg_type == NetMsgType::WTXIDRELAY) { + if (pfrom.fSuccessfullyConnected) { + // Disconnect peers that send wtxidrelay message after VERACK; this + // must be negotiated between VERSION and VERACK. + pfrom.fDisconnect = true; + return; + } + if (pfrom.nVersion >= WTXID_RELAY_VERSION) { + LOCK(cs_main); + if (!State(pfrom.GetId())->m_wtxid_relay) { + State(pfrom.GetId())->m_wtxid_relay = true; + g_wtxid_relay_peers++; + } + } + return; } - if (!pfrom->fSuccessfullyConnected) { + if (!pfrom.fSuccessfullyConnected) { // Must have a verack message before anything else LOCK(cs_main); - Misbehaving(pfrom->GetId(), 1); - return false; + Misbehaving(pfrom.GetId(), 1, "non-verack message before version handshake"); + return; } if (msg_type == NetMsgType::ADDR) { std::vector<CAddress> vAddr; vRecv >> vAddr; - // Don't want addr from older versions unless seeding - if (pfrom->nVersion < CADDR_TIME_VERSION && connman->GetAddressCount() > 1000) - return true; - if (!pfrom->IsAddrRelayPeer()) { - return true; + if (!pfrom.IsAddrRelayPeer()) { + return; } - if (vAddr.size() > 1000) + if (vAddr.size() > MAX_ADDR_TO_SEND) { LOCK(cs_main); - Misbehaving(pfrom->GetId(), 20, strprintf("message addr size() = %u", vAddr.size())); - return false; + Misbehaving(pfrom.GetId(), 20, strprintf("addr message size = %u", vAddr.size())); + return; } // Store the new addresses @@ -2467,7 +2579,7 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec for (CAddress& addr : vAddr) { if (interruptMsgProc) - return true; + return; // We only bother storing full nodes, though this may include // things which we would not make an outbound connection to, in @@ -2477,53 +2589,56 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec if (addr.nTime <= 100000000 || addr.nTime > nNow + 10 * 60) addr.nTime = nNow - 5 * 24 * 60 * 60; - pfrom->AddAddressKnown(addr); - if (banman->IsBanned(addr)) continue; // Do not process banned addresses beyond remembering we received them + pfrom.AddAddressKnown(addr); + if (banman && (banman->IsDiscouraged(addr) || banman->IsBanned(addr))) { + // Do not process banned/discouraged addresses beyond remembering we received them + continue; + } bool fReachable = IsReachable(addr); - if (addr.nTime > nSince && !pfrom->fGetAddr && vAddr.size() <= 10 && addr.IsRoutable()) + if (addr.nTime > nSince && !pfrom.fGetAddr && vAddr.size() <= 10 && addr.IsRoutable()) { // Relay to a limited number of other nodes - RelayAddress(addr, fReachable, *connman); + RelayAddress(addr, fReachable, connman); } // Do not store addresses outside our network if (fReachable) vAddrOk.push_back(addr); } - connman->AddNewAddresses(vAddrOk, pfrom->addr, 2 * 60 * 60); + connman.AddNewAddresses(vAddrOk, pfrom.addr, 2 * 60 * 60); if (vAddr.size() < 1000) - pfrom->fGetAddr = false; - if (pfrom->fOneShot) - pfrom->fDisconnect = true; - return true; + pfrom.fGetAddr = false; + if (pfrom.IsAddrFetchConn()) + pfrom.fDisconnect = true; + return; } if (msg_type == NetMsgType::SENDHEADERS) { LOCK(cs_main); - State(pfrom->GetId())->fPreferHeaders = true; - return true; + State(pfrom.GetId())->fPreferHeaders = true; + return; } if (msg_type == NetMsgType::SENDCMPCT) { bool fAnnounceUsingCMPCTBLOCK = false; uint64_t nCMPCTBLOCKVersion = 0; vRecv >> fAnnounceUsingCMPCTBLOCK >> nCMPCTBLOCKVersion; - if (nCMPCTBLOCKVersion == 1 || ((pfrom->GetLocalServices() & NODE_WITNESS) && nCMPCTBLOCKVersion == 2)) { + if (nCMPCTBLOCKVersion == 1 || ((pfrom.GetLocalServices() & NODE_WITNESS) && nCMPCTBLOCKVersion == 2)) { LOCK(cs_main); // fProvidesHeaderAndIDs is used to "lock in" version of compact blocks we send (fWantsCmpctWitness) - if (!State(pfrom->GetId())->fProvidesHeaderAndIDs) { - State(pfrom->GetId())->fProvidesHeaderAndIDs = true; - State(pfrom->GetId())->fWantsCmpctWitness = nCMPCTBLOCKVersion == 2; + if (!State(pfrom.GetId())->fProvidesHeaderAndIDs) { + State(pfrom.GetId())->fProvidesHeaderAndIDs = true; + State(pfrom.GetId())->fWantsCmpctWitness = nCMPCTBLOCKVersion == 2; } - if (State(pfrom->GetId())->fWantsCmpctWitness == (nCMPCTBLOCKVersion == 2)) // ignore later version announces - State(pfrom->GetId())->fPreferHeaderAndIDs = fAnnounceUsingCMPCTBLOCK; - if (!State(pfrom->GetId())->fSupportsDesiredCmpctVersion) { - if (pfrom->GetLocalServices() & NODE_WITNESS) - State(pfrom->GetId())->fSupportsDesiredCmpctVersion = (nCMPCTBLOCKVersion == 2); + if (State(pfrom.GetId())->fWantsCmpctWitness == (nCMPCTBLOCKVersion == 2)) // ignore later version announces + State(pfrom.GetId())->fPreferHeaderAndIDs = fAnnounceUsingCMPCTBLOCK; + if (!State(pfrom.GetId())->fSupportsDesiredCmpctVersion) { + if (pfrom.GetLocalServices() & NODE_WITNESS) + State(pfrom.GetId())->fSupportsDesiredCmpctVersion = (nCMPCTBLOCKVersion == 2); else - State(pfrom->GetId())->fSupportsDesiredCmpctVersion = (nCMPCTBLOCKVersion == 1); + State(pfrom.GetId())->fSupportsDesiredCmpctVersion = (nCMPCTBLOCKVersion == 1); } } - return true; + return; } if (msg_type == NetMsgType::INV) { @@ -2532,17 +2647,18 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec if (vInv.size() > MAX_INV_SZ) { LOCK(cs_main); - Misbehaving(pfrom->GetId(), 20, strprintf("message inv size() = %u", vInv.size())); - return false; + Misbehaving(pfrom.GetId(), 20, strprintf("inv message size = %u", vInv.size())); + return; } // We won't accept tx inv's if we're in blocks-only mode, or this is a // block-relay-only peer - bool fBlocksOnly = !g_relay_txes || (pfrom->m_tx_relay == nullptr); + bool fBlocksOnly = !g_relay_txes || (pfrom.m_tx_relay == nullptr); - // Allow whitelisted peers to send data other than blocks in blocks only mode if whitelistrelay is true - if (pfrom->HasPermission(PF_RELAY)) + // Allow peers with relay permission to send data other than blocks in blocks only mode + if (pfrom.HasPermission(PF_RELAY)) { fBlocksOnly = false; + } LOCK(cs_main); @@ -2553,17 +2669,26 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec for (CInv &inv : vInv) { if (interruptMsgProc) - return true; + return; + + // Ignore INVs that don't match wtxidrelay setting. + // Note that orphan parent fetching always uses MSG_TX GETDATAs regardless of the wtxidrelay setting. + // This is fine as no INV messages are involved in that process. + if (State(pfrom.GetId())->m_wtxid_relay) { + if (inv.IsMsgTx()) continue; + } else { + if (inv.IsMsgWtx()) continue; + } bool fAlreadyHave = AlreadyHave(inv, mempool); - LogPrint(BCLog::NET, "got inv: %s %s peer=%d\n", inv.ToString(), fAlreadyHave ? "have" : "new", pfrom->GetId()); + LogPrint(BCLog::NET, "got inv: %s %s peer=%d\n", inv.ToString(), fAlreadyHave ? "have" : "new", pfrom.GetId()); - if (inv.type == MSG_TX) { + if (inv.IsMsgTx()) { inv.type |= nFetchFlags; } if (inv.type == MSG_BLOCK) { - UpdateBlockAvailability(pfrom->GetId(), inv.hash); + UpdateBlockAvailability(pfrom.GetId(), inv.hash); if (!fAlreadyHave && !fImporting && !fReindex && !mapBlocksInFlight.count(inv.hash)) { // Headers-first is the primary method of announcement on // the network. If a node fell back to sending blocks by inv, @@ -2573,23 +2698,23 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec best_block = &inv.hash; } } else { - pfrom->AddInventoryKnown(inv); + pfrom.AddKnownTx(inv.hash); if (fBlocksOnly) { - LogPrint(BCLog::NET, "transaction (%s) inv sent in violation of protocol, disconnecting peer=%d\n", inv.hash.ToString(), pfrom->GetId()); - pfrom->fDisconnect = true; - return true; - } else if (!fAlreadyHave && !fImporting && !fReindex && !::ChainstateActive().IsInitialBlockDownload()) { - RequestTx(State(pfrom->GetId()), inv.hash, current_time); + LogPrint(BCLog::NET, "transaction (%s) inv sent in violation of protocol, disconnecting peer=%d\n", inv.hash.ToString(), pfrom.GetId()); + pfrom.fDisconnect = true; + return; + } else if (!fAlreadyHave && !chainman.ActiveChainstate().IsInitialBlockDownload()) { + RequestTx(State(pfrom.GetId()), ToGenTxid(inv), current_time); } } } if (best_block != nullptr) { - connman->PushMessage(pfrom, msgMaker.Make(NetMsgType::GETHEADERS, ::ChainActive().GetLocator(pindexBestHeader), *best_block)); - LogPrint(BCLog::NET, "getheaders (%d) %s to peer=%d\n", pindexBestHeader->nHeight, best_block->ToString(), pfrom->GetId()); + connman.PushMessage(&pfrom, msgMaker.Make(NetMsgType::GETHEADERS, ::ChainActive().GetLocator(pindexBestHeader), *best_block)); + LogPrint(BCLog::NET, "getheaders (%d) %s to peer=%d\n", pindexBestHeader->nHeight, best_block->ToString(), pfrom.GetId()); } - return true; + return; } if (msg_type == NetMsgType::GETDATA) { @@ -2598,19 +2723,19 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec if (vInv.size() > MAX_INV_SZ) { LOCK(cs_main); - Misbehaving(pfrom->GetId(), 20, strprintf("message getdata size() = %u", vInv.size())); - return false; + Misbehaving(pfrom.GetId(), 20, strprintf("getdata message size = %u", vInv.size())); + return; } - LogPrint(BCLog::NET, "received getdata (%u invsz) peer=%d\n", vInv.size(), pfrom->GetId()); + LogPrint(BCLog::NET, "received getdata (%u invsz) peer=%d\n", vInv.size(), pfrom.GetId()); if (vInv.size() > 0) { - LogPrint(BCLog::NET, "received getdata for: %s peer=%d\n", vInv[0].ToString(), pfrom->GetId()); + LogPrint(BCLog::NET, "received getdata for: %s peer=%d\n", vInv[0].ToString(), pfrom.GetId()); } - pfrom->vRecvGetData.insert(pfrom->vRecvGetData.end(), vInv.begin(), vInv.end()); + pfrom.vRecvGetData.insert(pfrom.vRecvGetData.end(), vInv.begin(), vInv.end()); ProcessGetData(pfrom, chainparams, connman, mempool, interruptMsgProc); - return true; + return; } if (msg_type == NetMsgType::GETBLOCKS) { @@ -2619,9 +2744,9 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec vRecv >> locator >> hashStop; if (locator.vHave.size() > MAX_LOCATOR_SZ) { - LogPrint(BCLog::NET, "getblocks locator size %lld > %d, disconnect peer=%d\n", locator.vHave.size(), MAX_LOCATOR_SZ, pfrom->GetId()); - pfrom->fDisconnect = true; - return true; + LogPrint(BCLog::NET, "getblocks locator size %lld > %d, disconnect peer=%d\n", locator.vHave.size(), MAX_LOCATOR_SZ, pfrom.GetId()); + pfrom.fDisconnect = true; + return; } // We might have announced the currently-being-connected tip using a @@ -2652,7 +2777,7 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec if (pindex) pindex = ::ChainActive().Next(pindex); int nLimit = 500; - LogPrint(BCLog::NET, "getblocks %d to %s limit %d from peer=%d\n", (pindex ? pindex->nHeight : -1), hashStop.IsNull() ? "end" : hashStop.ToString(), nLimit, pfrom->GetId()); + LogPrint(BCLog::NET, "getblocks %d to %s limit %d from peer=%d\n", (pindex ? pindex->nHeight : -1), hashStop.IsNull() ? "end" : hashStop.ToString(), nLimit, pfrom.GetId()); for (; pindex; pindex = ::ChainActive().Next(pindex)) { if (pindex->GetBlockHash() == hashStop) @@ -2668,17 +2793,17 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec LogPrint(BCLog::NET, " getblocks stopping, pruned or too old block at %d %s\n", pindex->nHeight, pindex->GetBlockHash().ToString()); break; } - pfrom->PushInventory(CInv(MSG_BLOCK, pindex->GetBlockHash())); + WITH_LOCK(pfrom.cs_inventory, pfrom.vInventoryBlockToSend.push_back(pindex->GetBlockHash())); if (--nLimit <= 0) { // When this block is requested, we'll send an inv that'll // trigger the peer to getblocks the next batch of inventory. LogPrint(BCLog::NET, " getblocks stopping at limit %d %s\n", pindex->nHeight, pindex->GetBlockHash().ToString()); - pfrom->hashContinue = pindex->GetBlockHash(); + pfrom.hashContinue = pindex->GetBlockHash(); break; } } - return true; + return; } if (msg_type == NetMsgType::GETBLOCKTXN) { @@ -2694,15 +2819,15 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec } if (recent_block) { SendBlockTransactions(*recent_block, req, pfrom, connman); - return true; + return; } LOCK(cs_main); const CBlockIndex* pindex = LookupBlockIndex(req.blockhash); if (!pindex || !(pindex->nStatus & BLOCK_HAVE_DATA)) { - LogPrint(BCLog::NET, "Peer %d sent us a getblocktxn for a block we don't have\n", pfrom->GetId()); - return true; + LogPrint(BCLog::NET, "Peer %d sent us a getblocktxn for a block we don't have\n", pfrom.GetId()); + return; } if (pindex->nHeight < ::ChainActive().Height() - MAX_BLOCKTXN_DEPTH) { @@ -2713,13 +2838,13 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec // might maliciously send lots of getblocktxn requests to trigger // expensive disk reads, because it will require the peer to // actually receive all the data read from disk over the network. - LogPrint(BCLog::NET, "Peer %d sent us a getblocktxn for a block > %i deep\n", pfrom->GetId(), MAX_BLOCKTXN_DEPTH); + LogPrint(BCLog::NET, "Peer %d sent us a getblocktxn for a block > %i deep\n", pfrom.GetId(), MAX_BLOCKTXN_DEPTH); CInv inv; - inv.type = State(pfrom->GetId())->fWantsCmpctWitness ? MSG_WITNESS_BLOCK : MSG_BLOCK; + inv.type = State(pfrom.GetId())->fWantsCmpctWitness ? MSG_WITNESS_BLOCK : MSG_BLOCK; inv.hash = req.blockhash; - pfrom->vRecvGetData.push_back(inv); + pfrom.vRecvGetData.push_back(inv); // The message processing loop will go around again (without pausing) and we'll respond then (without cs_main) - return true; + return; } CBlock block; @@ -2727,7 +2852,7 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec assert(ret); SendBlockTransactions(block, req, pfrom, connman); - return true; + return; } if (msg_type == NetMsgType::GETHEADERS) { @@ -2736,30 +2861,30 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec vRecv >> locator >> hashStop; if (locator.vHave.size() > MAX_LOCATOR_SZ) { - LogPrint(BCLog::NET, "getheaders locator size %lld > %d, disconnect peer=%d\n", locator.vHave.size(), MAX_LOCATOR_SZ, pfrom->GetId()); - pfrom->fDisconnect = true; - return true; + LogPrint(BCLog::NET, "getheaders locator size %lld > %d, disconnect peer=%d\n", locator.vHave.size(), MAX_LOCATOR_SZ, pfrom.GetId()); + pfrom.fDisconnect = true; + return; } LOCK(cs_main); - if (::ChainstateActive().IsInitialBlockDownload() && !pfrom->HasPermission(PF_NOBAN)) { - LogPrint(BCLog::NET, "Ignoring getheaders from peer=%d because node is in initial block download\n", pfrom->GetId()); - return true; + if (::ChainstateActive().IsInitialBlockDownload() && !pfrom.HasPermission(PF_DOWNLOAD)) { + LogPrint(BCLog::NET, "Ignoring getheaders from peer=%d because node is in initial block download\n", pfrom.GetId()); + return; } - CNodeState *nodestate = State(pfrom->GetId()); + CNodeState *nodestate = State(pfrom.GetId()); const CBlockIndex* pindex = nullptr; if (locator.IsNull()) { // If locator is null, return the hashStop block pindex = LookupBlockIndex(hashStop); if (!pindex) { - return true; + return; } if (!BlockRequestAllowed(pindex, chainparams.GetConsensus())) { - LogPrint(BCLog::NET, "%s: ignoring request from peer=%i for old block header that isn't in the main chain\n", __func__, pfrom->GetId()); - return true; + LogPrint(BCLog::NET, "%s: ignoring request from peer=%i for old block header that isn't in the main chain\n", __func__, pfrom.GetId()); + return; } } else @@ -2773,7 +2898,7 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec // we must use CBlocks, as CBlockHeaders won't include the 0x00 nTx count at the end std::vector<CBlock> vHeaders; int nLimit = MAX_HEADERS_RESULTS; - LogPrint(BCLog::NET, "getheaders %d to %s from peer=%d\n", (pindex ? pindex->nHeight : -1), hashStop.IsNull() ? "end" : hashStop.ToString(), pfrom->GetId()); + LogPrint(BCLog::NET, "getheaders %d to %s from peer=%d\n", (pindex ? pindex->nHeight : -1), hashStop.IsNull() ? "end" : hashStop.ToString(), pfrom.GetId()); for (; pindex; pindex = ::ChainActive().Next(pindex)) { vHeaders.push_back(pindex->GetBlockHeader()); @@ -2793,67 +2918,104 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec // will re-announce the new block via headers (or compact blocks again) // in the SendMessages logic. nodestate->pindexBestHeaderSent = pindex ? pindex : ::ChainActive().Tip(); - connman->PushMessage(pfrom, msgMaker.Make(NetMsgType::HEADERS, vHeaders)); - return true; + connman.PushMessage(&pfrom, msgMaker.Make(NetMsgType::HEADERS, vHeaders)); + return; } if (msg_type == NetMsgType::TX) { // Stop processing the transaction early if // 1) We are in blocks only mode and peer has no relay permission // 2) This peer is a block-relay-only peer - if ((!g_relay_txes && !pfrom->HasPermission(PF_RELAY)) || (pfrom->m_tx_relay == nullptr)) + if ((!g_relay_txes && !pfrom.HasPermission(PF_RELAY)) || (pfrom.m_tx_relay == nullptr)) { - LogPrint(BCLog::NET, "transaction sent in violation of protocol peer=%d\n", pfrom->GetId()); - pfrom->fDisconnect = true; - return true; + LogPrint(BCLog::NET, "transaction sent in violation of protocol peer=%d\n", pfrom.GetId()); + pfrom.fDisconnect = true; + return; } CTransactionRef ptx; vRecv >> ptx; const CTransaction& tx = *ptx; - CInv inv(MSG_TX, tx.GetHash()); - pfrom->AddInventoryKnown(inv); + const uint256& txid = ptx->GetHash(); + const uint256& wtxid = ptx->GetWitnessHash(); LOCK2(cs_main, g_cs_orphans); + CNodeState* nodestate = State(pfrom.GetId()); + + const uint256& hash = nodestate->m_wtxid_relay ? wtxid : txid; + pfrom.AddKnownTx(hash); + if (nodestate->m_wtxid_relay && txid != wtxid) { + // Insert txid into filterInventoryKnown, even for + // wtxidrelay peers. This prevents re-adding of + // unconfirmed parents to the recently_announced + // filter, when a child tx is requested. See + // ProcessGetData(). + pfrom.AddKnownTx(txid); + } + TxValidationState state; - CNodeState* nodestate = State(pfrom->GetId()); - nodestate->m_tx_download.m_tx_announced.erase(inv.hash); - nodestate->m_tx_download.m_tx_in_flight.erase(inv.hash); - EraseTxRequest(inv.hash); + for (const GenTxid& gtxid : {GenTxid(false, txid), GenTxid(true, wtxid)}) { + nodestate->m_tx_download.m_tx_announced.erase(gtxid.GetHash()); + nodestate->m_tx_download.m_tx_in_flight.erase(gtxid.GetHash()); + EraseTxRequest(gtxid); + } std::list<CTransactionRef> lRemovedTxn; - if (!AlreadyHave(inv, mempool) && + // We do the AlreadyHave() check using wtxid, rather than txid - in the + // absence of witness malleation, this is strictly better, because the + // recent rejects filter may contain the wtxid but rarely contains + // the txid of a segwit transaction that has been rejected. + // In the presence of witness malleation, it's possible that by only + // doing the check with wtxid, we could overlook a transaction which + // was confirmed with a different witness, or exists in our mempool + // with a different witness, but this has limited downside: + // mempool validation does its own lookup of whether we have the txid + // already; and an adversary can already relay us old transactions + // (older than our recency filter) if trying to DoS us, without any need + // for witness malleation. + if (!AlreadyHave(CInv(MSG_WTX, wtxid), mempool) && AcceptToMemoryPool(mempool, state, ptx, &lRemovedTxn, false /* bypass_limits */, 0 /* nAbsurdFee */)) { mempool.check(&::ChainstateActive().CoinsTip()); - RelayTransaction(tx.GetHash(), *connman); + RelayTransaction(tx.GetHash(), tx.GetWitnessHash(), connman); for (unsigned int i = 0; i < tx.vout.size(); i++) { - auto it_by_prev = mapOrphanTransactionsByPrev.find(COutPoint(inv.hash, i)); + auto it_by_prev = mapOrphanTransactionsByPrev.find(COutPoint(txid, i)); if (it_by_prev != mapOrphanTransactionsByPrev.end()) { for (const auto& elem : it_by_prev->second) { - pfrom->orphan_work_set.insert(elem->first); + pfrom.orphan_work_set.insert(elem->first); } } } - pfrom->nLastTXTime = GetTime(); + pfrom.nLastTXTime = GetTime(); LogPrint(BCLog::MEMPOOL, "AcceptToMemoryPool: peer=%d: accepted %s (poolsz %u txn, %u kB)\n", - pfrom->GetId(), + pfrom.GetId(), tx.GetHash().ToString(), mempool.size(), mempool.DynamicMemoryUsage() / 1000); // Recursively process any orphan transactions that depended on this one - ProcessOrphanTx(connman, mempool, pfrom->orphan_work_set, lRemovedTxn); + ProcessOrphanTx(connman, mempool, pfrom.orphan_work_set, lRemovedTxn); } else if (state.GetResult() == TxValidationResult::TX_MISSING_INPUTS) { bool fRejectedParents = false; // It may be the case that the orphans parents have all been rejected + + // Deduplicate parent txids, so that we don't have to loop over + // the same parent txid more than once down below. + std::vector<uint256> unique_parents; + unique_parents.reserve(tx.vin.size()); for (const CTxIn& txin : tx.vin) { - if (recentRejects->contains(txin.prevout.hash)) { + // We start with all parents, and then remove duplicates below. + unique_parents.push_back(txin.prevout.hash); + } + std::sort(unique_parents.begin(), unique_parents.end()); + unique_parents.erase(std::unique(unique_parents.begin(), unique_parents.end()), unique_parents.end()); + for (const uint256& parent_txid : unique_parents) { + if (recentRejects->contains(parent_txid)) { fRejectedParents = true; break; } @@ -2862,12 +3024,17 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec uint32_t nFetchFlags = GetFetchFlags(pfrom); const auto current_time = GetTime<std::chrono::microseconds>(); - for (const CTxIn& txin : tx.vin) { - CInv _inv(MSG_TX | nFetchFlags, txin.prevout.hash); - pfrom->AddInventoryKnown(_inv); - if (!AlreadyHave(_inv, mempool)) RequestTx(State(pfrom->GetId()), _inv.hash, current_time); + for (const uint256& parent_txid : unique_parents) { + // Here, we only have the txid (and not wtxid) of the + // inputs, so we only request in txid mode, even for + // wtxidrelay peers. + // Eventually we should replace this with an improved + // protocol for getting all unconfirmed parents. + CInv _inv(MSG_TX | nFetchFlags, parent_txid); + pfrom.AddKnownTx(parent_txid); + if (!AlreadyHave(_inv, mempool)) RequestTx(State(pfrom.GetId()), ToGenTxid(_inv), current_time); } - AddOrphanTx(ptx, pfrom->GetId()); + AddOrphanTx(ptx, pfrom.GetId()); // DoS prevention: do not allow mapOrphanTransactions to grow unbounded (see CVE-2012-3789) unsigned int nMaxOrphanTx = (unsigned int)std::max((int64_t)0, gArgs.GetArg("-maxorphantx", DEFAULT_MAX_ORPHAN_TRANSACTIONS)); @@ -2879,15 +3046,41 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec LogPrint(BCLog::MEMPOOL, "not keeping orphan with rejected parents %s\n",tx.GetHash().ToString()); // We will continue to reject this tx since it has rejected // parents so avoid re-requesting it from other peers. + // Here we add both the txid and the wtxid, as we know that + // regardless of what witness is provided, we will not accept + // this, so we don't need to allow for redownload of this txid + // from any of our non-wtxidrelay peers. recentRejects->insert(tx.GetHash()); + recentRejects->insert(tx.GetWitnessHash()); } } else { - if (!tx.HasWitness() && state.GetResult() != TxValidationResult::TX_WITNESS_MUTATED) { - // Do not use rejection cache for witness transactions or - // witness-stripped transactions, as they can have been malleated. - // See https://github.com/bitcoin/bitcoin/issues/8279 for details. + if (state.GetResult() != TxValidationResult::TX_WITNESS_STRIPPED) { + // We can add the wtxid of this transaction to our reject filter. + // Do not add txids of witness transactions or witness-stripped + // transactions to the filter, as they can have been malleated; + // adding such txids to the reject filter would potentially + // interfere with relay of valid transactions from peers that + // do not support wtxid-based relay. See + // https://github.com/bitcoin/bitcoin/issues/8279 for details. + // We can remove this restriction (and always add wtxids to + // the filter even for witness stripped transactions) once + // wtxid-based relay is broadly deployed. + // See also comments in https://github.com/bitcoin/bitcoin/pull/18044#discussion_r443419034 + // for concerns around weakening security of unupgraded nodes + // if we start doing this too early. assert(recentRejects); - recentRejects->insert(tx.GetHash()); + recentRejects->insert(tx.GetWitnessHash()); + // If the transaction failed for TX_INPUTS_NOT_STANDARD, + // then we know that the witness was irrelevant to the policy + // failure, since this check depends only on the txid + // (the scriptPubKey being spent is covered by the txid). + // Add the txid to the reject filter to prevent repeated + // processing of this transaction in the event that child + // transactions are later received (resulting in + // parent-fetching by txid via the orphan-handling logic). + if (state.GetResult() == TxValidationResult::TX_INPUTS_NOT_STANDARD && tx.GetWitnessHash() != tx.GetHash()) { + recentRejects->insert(tx.GetHash()); + } if (RecursiveDynamicUsage(*ptx) < 100000) { AddToCompactExtraTransactions(ptx); } @@ -2895,16 +3088,16 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec AddToCompactExtraTransactions(ptx); } - if (pfrom->HasPermission(PF_FORCERELAY)) { - // Always relay transactions received from whitelisted peers, even + if (pfrom.HasPermission(PF_FORCERELAY)) { + // Always relay transactions received from peers with forcerelay permission, even // if they were already in the mempool, // allowing the node to function as a gateway for // nodes hidden behind it. if (!mempool.exists(tx.GetHash())) { - LogPrintf("Not relaying non-mempool transaction %s from whitelisted peer=%d\n", tx.GetHash().ToString(), pfrom->GetId()); + LogPrintf("Not relaying non-mempool transaction %s from forcerelay peer=%d\n", tx.GetHash().ToString(), pfrom.GetId()); } else { - LogPrintf("Force relaying tx %s from whitelisted peer=%d\n", tx.GetHash().ToString(), pfrom->GetId()); - RelayTransaction(tx.GetHash(), *connman); + LogPrintf("Force relaying tx %s from peer=%d\n", tx.GetHash().ToString(), pfrom.GetId()); + RelayTransaction(tx.GetHash(), tx.GetWitnessHash(), connman); } } } @@ -2929,22 +3122,21 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec // peer simply for relaying a tx that our recentRejects has caught, // regardless of false positives. - if (state.IsInvalid()) - { + if (state.IsInvalid()) { LogPrint(BCLog::MEMPOOLREJ, "%s from peer=%d was not accepted: %s\n", tx.GetHash().ToString(), - pfrom->GetId(), + pfrom.GetId(), state.ToString()); - MaybePunishNodeForTx(pfrom->GetId(), state); + MaybePunishNodeForTx(pfrom.GetId(), state); } - return true; + return; } if (msg_type == NetMsgType::CMPCTBLOCK) { // Ignore cmpctblock received while importing if (fImporting || fReindex) { - LogPrint(BCLog::NET, "Unexpected cmpctblock message received from peer %d\n", pfrom->GetId()); - return true; + LogPrint(BCLog::NET, "Unexpected cmpctblock message received from peer %d\n", pfrom.GetId()); + return; } CBlockHeaderAndShortTxIDs cmpctblock; @@ -2958,8 +3150,8 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec if (!LookupBlockIndex(cmpctblock.header.hashPrevBlock)) { // Doesn't connect (or is genesis), instead of DoSing in AcceptBlockHeader, request deeper headers if (!::ChainstateActive().IsInitialBlockDownload()) - connman->PushMessage(pfrom, msgMaker.Make(NetMsgType::GETHEADERS, ::ChainActive().GetLocator(pindexBestHeader), uint256())); - return true; + connman.PushMessage(&pfrom, msgMaker.Make(NetMsgType::GETHEADERS, ::ChainActive().GetLocator(pindexBestHeader), uint256())); + return; } if (!LookupBlockIndex(cmpctblock.header.GetHash())) { @@ -2971,8 +3163,8 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec BlockValidationState state; if (!chainman.ProcessNewBlockHeaders({cmpctblock.header}, state, chainparams, &pindex)) { if (state.IsInvalid()) { - MaybePunishNodeForBlock(pfrom->GetId(), state, /*via_compact_block*/ true, "invalid header via cmpctblock"); - return true; + MaybePunishNodeForBlock(pfrom.GetId(), state, /*via_compact_block*/ true, "invalid header via cmpctblock"); + return; } } @@ -2996,9 +3188,9 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec LOCK2(cs_main, g_cs_orphans); // If AcceptBlockHeader returned true, it set pindex assert(pindex); - UpdateBlockAvailability(pfrom->GetId(), pindex->GetBlockHash()); + UpdateBlockAvailability(pfrom.GetId(), pindex->GetBlockHash()); - CNodeState *nodestate = State(pfrom->GetId()); + CNodeState *nodestate = State(pfrom.GetId()); // If this was a new header with more work than our tip, update the // peer's last block announcement time @@ -3010,7 +3202,7 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec bool fAlreadyInFlight = blockInFlightIt != mapBlocksInFlight.end(); if (pindex->nStatus & BLOCK_HAVE_DATA) // Nothing to do here - return true; + return; if (pindex->nChainWork <= ::ChainActive().Tip()->nChainWork || // We know something better pindex->nTx != 0) { // We had this block at some point, but pruned it @@ -3019,49 +3211,49 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec // so we just grab the block via normal getdata std::vector<CInv> vInv(1); vInv[0] = CInv(MSG_BLOCK | GetFetchFlags(pfrom), cmpctblock.header.GetHash()); - connman->PushMessage(pfrom, msgMaker.Make(NetMsgType::GETDATA, vInv)); + connman.PushMessage(&pfrom, msgMaker.Make(NetMsgType::GETDATA, vInv)); } - return true; + return; } // If we're not close to tip yet, give up and let parallel block fetch work its magic if (!fAlreadyInFlight && !CanDirectFetch(chainparams.GetConsensus())) - return true; + return; if (IsWitnessEnabled(pindex->pprev, chainparams.GetConsensus()) && !nodestate->fSupportsDesiredCmpctVersion) { // Don't bother trying to process compact blocks from v1 peers // after segwit activates. - return true; + return; } // We want to be a bit conservative just to be extra careful about DoS // possibilities in compact block processing... if (pindex->nHeight <= ::ChainActive().Height() + 2) { if ((!fAlreadyInFlight && nodestate->nBlocksInFlight < MAX_BLOCKS_IN_TRANSIT_PER_PEER) || - (fAlreadyInFlight && blockInFlightIt->second.first == pfrom->GetId())) { + (fAlreadyInFlight && blockInFlightIt->second.first == pfrom.GetId())) { std::list<QueuedBlock>::iterator* queuedBlockIt = nullptr; - if (!MarkBlockAsInFlight(mempool, pfrom->GetId(), pindex->GetBlockHash(), pindex, &queuedBlockIt)) { + if (!MarkBlockAsInFlight(mempool, pfrom.GetId(), pindex->GetBlockHash(), pindex, &queuedBlockIt)) { if (!(*queuedBlockIt)->partialBlock) (*queuedBlockIt)->partialBlock.reset(new PartiallyDownloadedBlock(&mempool)); else { // The block was already in flight using compact blocks from the same peer LogPrint(BCLog::NET, "Peer sent us compact block we were already syncing!\n"); - return true; + return; } } PartiallyDownloadedBlock& partialBlock = *(*queuedBlockIt)->partialBlock; ReadStatus status = partialBlock.InitData(cmpctblock, vExtraTxnForCompact); if (status == READ_STATUS_INVALID) { - MarkBlockAsReceived(pindex->GetBlockHash()); // Reset in-flight state in case of whitelist - Misbehaving(pfrom->GetId(), 100, strprintf("Peer %d sent us invalid compact block\n", pfrom->GetId())); - return true; + MarkBlockAsReceived(pindex->GetBlockHash()); // Reset in-flight state in case Misbehaving does not result in a disconnect + Misbehaving(pfrom.GetId(), 100, "invalid compact block"); + return; } else if (status == READ_STATUS_FAILED) { // Duplicate txindexes, the block is now in-flight, so just request it std::vector<CInv> vInv(1); vInv[0] = CInv(MSG_BLOCK | GetFetchFlags(pfrom), cmpctblock.header.GetHash()); - connman->PushMessage(pfrom, msgMaker.Make(NetMsgType::GETDATA, vInv)); - return true; + connman.PushMessage(&pfrom, msgMaker.Make(NetMsgType::GETDATA, vInv)); + return; } BlockTransactionsRequest req; @@ -3077,7 +3269,7 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec fProcessBLOCKTXN = true; } else { req.blockhash = pindex->GetBlockHash(); - connman->PushMessage(pfrom, msgMaker.Make(NetMsgType::GETBLOCKTXN, req)); + connman.PushMessage(&pfrom, msgMaker.Make(NetMsgType::GETBLOCKTXN, req)); } } else { // This block is either already in flight from a different @@ -3089,7 +3281,7 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec ReadStatus status = tempBlock.InitData(cmpctblock, vExtraTxnForCompact); if (status != READ_STATUS_OK) { // TODO: don't ignore failures - return true; + return; } std::vector<CTransactionRef> dummy; status = tempBlock.FillBlock(*pblock, dummy); @@ -3103,8 +3295,8 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec // mempool will probably be useless - request the block normally std::vector<CInv> vInv(1); vInv[0] = CInv(MSG_BLOCK | GetFetchFlags(pfrom), cmpctblock.header.GetHash()); - connman->PushMessage(pfrom, msgMaker.Make(NetMsgType::GETDATA, vInv)); - return true; + connman.PushMessage(&pfrom, msgMaker.Make(NetMsgType::GETDATA, vInv)); + return; } else { // If this was an announce-cmpctblock, we want the same treatment as a header message fRevertToHeaderProcessing = true; @@ -3113,14 +3305,14 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec } // cs_main if (fProcessBLOCKTXN) - return ProcessMessage(pfrom, NetMsgType::BLOCKTXN, blockTxnMsg, nTimeReceived, chainparams, chainman, mempool, connman, banman, interruptMsgProc); + return ProcessMessage(pfrom, NetMsgType::BLOCKTXN, blockTxnMsg, time_received, chainparams, chainman, mempool, connman, banman, interruptMsgProc); if (fRevertToHeaderProcessing) { // Headers received from HB compact block peers are permitted to be // relayed before full validation (see BIP 152), so we don't want to disconnect // the peer if the header turns out to be for an invalid block. // Note that if a peer tries to build on an invalid chain, that - // will be detected and the peer will be banned. + // will be detected and the peer will be disconnected/discouraged. return ProcessHeadersMessage(pfrom, connman, chainman, mempool, {cmpctblock.header}, chainparams, /*via_compact_block=*/true); } @@ -3129,7 +3321,7 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec // block that is in flight from some other peer. { LOCK(cs_main); - mapBlockSource.emplace(pblock->GetHash(), std::make_pair(pfrom->GetId(), false)); + mapBlockSource.emplace(pblock->GetHash(), std::make_pair(pfrom.GetId(), false)); } bool fNewBlock = false; // Setting fForceProcessing to true means that we bypass some of @@ -3143,7 +3335,7 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec // reconstructed compact blocks as having been requested. chainman.ProcessNewBlock(chainparams, pblock, /*fForceProcessing=*/true, &fNewBlock); if (fNewBlock) { - pfrom->nLastBlockTime = GetTime(); + pfrom.nLastBlockTime = GetTime(); } else { LOCK(cs_main); mapBlockSource.erase(pblock->GetHash()); @@ -3157,15 +3349,15 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec MarkBlockAsReceived(pblock->GetHash()); } } - return true; + return; } if (msg_type == NetMsgType::BLOCKTXN) { // Ignore blocktxn received while importing if (fImporting || fReindex) { - LogPrint(BCLog::NET, "Unexpected blocktxn message received from peer %d\n", pfrom->GetId()); - return true; + LogPrint(BCLog::NET, "Unexpected blocktxn message received from peer %d\n", pfrom.GetId()); + return; } BlockTransactions resp; @@ -3178,22 +3370,22 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec std::map<uint256, std::pair<NodeId, std::list<QueuedBlock>::iterator> >::iterator it = mapBlocksInFlight.find(resp.blockhash); if (it == mapBlocksInFlight.end() || !it->second.second->partialBlock || - it->second.first != pfrom->GetId()) { - LogPrint(BCLog::NET, "Peer %d sent us block transactions for block we weren't expecting\n", pfrom->GetId()); - return true; + it->second.first != pfrom.GetId()) { + LogPrint(BCLog::NET, "Peer %d sent us block transactions for block we weren't expecting\n", pfrom.GetId()); + return; } PartiallyDownloadedBlock& partialBlock = *it->second.second->partialBlock; ReadStatus status = partialBlock.FillBlock(*pblock, resp.txn); if (status == READ_STATUS_INVALID) { - MarkBlockAsReceived(resp.blockhash); // Reset in-flight state in case of whitelist - Misbehaving(pfrom->GetId(), 100, strprintf("Peer %d sent us invalid compact block/non-matching block transactions\n", pfrom->GetId())); - return true; + MarkBlockAsReceived(resp.blockhash); // Reset in-flight state in case Misbehaving does not result in a disconnect + Misbehaving(pfrom.GetId(), 100, "invalid compact block/non-matching block transactions"); + return; } else if (status == READ_STATUS_FAILED) { // Might have collided, fall back to getdata now :( std::vector<CInv> invs; invs.push_back(CInv(MSG_BLOCK | GetFetchFlags(pfrom), resp.blockhash)); - connman->PushMessage(pfrom, msgMaker.Make(NetMsgType::GETDATA, invs)); + connman.PushMessage(&pfrom, msgMaker.Make(NetMsgType::GETDATA, invs)); } else { // Block is either okay, or possibly we received // READ_STATUS_CHECKBLOCK_FAILED. @@ -3206,7 +3398,7 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec // 3. the block is otherwise invalid (eg invalid coinbase, // block is too big, too many legacy sigops, etc). // So if CheckBlock failed, #3 is the only possibility. - // Under BIP 152, we don't DoS-ban unless proof of work is + // Under BIP 152, we don't discourage the peer unless proof of work is // invalid (we don't require all the stateless checks to have // been run). This is handled below, so just treat this as // though the block was successfully read, and rely on the @@ -3220,7 +3412,7 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec // BIP 152 permits peers to relay compact blocks after validating // the header only; we should not punish peers if the block turns // out to be invalid. - mapBlockSource.emplace(resp.blockhash, std::make_pair(pfrom->GetId(), false)); + mapBlockSource.emplace(resp.blockhash, std::make_pair(pfrom.GetId(), false)); } } // Don't hold cs_main when we call into ProcessNewBlock if (fBlockRead) { @@ -3233,21 +3425,21 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec // in compact block optimistic reconstruction handling. chainman.ProcessNewBlock(chainparams, pblock, /*fForceProcessing=*/true, &fNewBlock); if (fNewBlock) { - pfrom->nLastBlockTime = GetTime(); + pfrom.nLastBlockTime = GetTime(); } else { LOCK(cs_main); mapBlockSource.erase(pblock->GetHash()); } } - return true; + return; } if (msg_type == NetMsgType::HEADERS) { // Ignore headers received while importing if (fImporting || fReindex) { - LogPrint(BCLog::NET, "Unexpected headers message received from peer %d\n", pfrom->GetId()); - return true; + LogPrint(BCLog::NET, "Unexpected headers message received from peer %d\n", pfrom.GetId()); + return; } std::vector<CBlockHeader> headers; @@ -3256,8 +3448,8 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec unsigned int nCount = ReadCompactSize(vRecv); if (nCount > MAX_HEADERS_RESULTS) { LOCK(cs_main); - Misbehaving(pfrom->GetId(), 20, strprintf("headers message size = %u", nCount)); - return false; + Misbehaving(pfrom.GetId(), 20, strprintf("headers message size = %u", nCount)); + return; } headers.resize(nCount); for (unsigned int n = 0; n < nCount; n++) { @@ -3272,14 +3464,14 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec { // Ignore block received while importing if (fImporting || fReindex) { - LogPrint(BCLog::NET, "Unexpected block message received from peer %d\n", pfrom->GetId()); - return true; + LogPrint(BCLog::NET, "Unexpected block message received from peer %d\n", pfrom.GetId()); + return; } std::shared_ptr<CBlock> pblock = std::make_shared<CBlock>(); vRecv >> *pblock; - LogPrint(BCLog::NET, "received block %s peer=%d\n", pblock->GetHash().ToString(), pfrom->GetId()); + LogPrint(BCLog::NET, "received block %s peer=%d\n", pblock->GetHash().ToString(), pfrom.GetId()); bool forceProcessing = false; const uint256 hash(pblock->GetHash()); @@ -3291,17 +3483,17 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec // mapBlockSource is only used for punishing peers and setting // which peers send us compact blocks, so the race between here and // cs_main in ProcessNewBlock is fine. - mapBlockSource.emplace(hash, std::make_pair(pfrom->GetId(), true)); + mapBlockSource.emplace(hash, std::make_pair(pfrom.GetId(), true)); } bool fNewBlock = false; chainman.ProcessNewBlock(chainparams, pblock, forceProcessing, &fNewBlock); if (fNewBlock) { - pfrom->nLastBlockTime = GetTime(); + pfrom.nLastBlockTime = GetTime(); } else { LOCK(cs_main); mapBlockSource.erase(pblock->GetHash()); } - return true; + return; } if (msg_type == NetMsgType::GETADDR) { @@ -3310,64 +3502,67 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec // to users' AddrMan and later request them by sending getaddr messages. // Making nodes which are behind NAT and can only make outgoing connections ignore // the getaddr message mitigates the attack. - if (!pfrom->fInbound) { - LogPrint(BCLog::NET, "Ignoring \"getaddr\" from outbound connection. peer=%d\n", pfrom->GetId()); - return true; + if (!pfrom.IsInboundConn()) { + LogPrint(BCLog::NET, "Ignoring \"getaddr\" from outbound connection. peer=%d\n", pfrom.GetId()); + return; } - if (!pfrom->IsAddrRelayPeer()) { - LogPrint(BCLog::NET, "Ignoring \"getaddr\" from block-relay-only connection. peer=%d\n", pfrom->GetId()); - return true; + if (!pfrom.IsAddrRelayPeer()) { + LogPrint(BCLog::NET, "Ignoring \"getaddr\" from block-relay-only connection. peer=%d\n", pfrom.GetId()); + return; } // Only send one GetAddr response per connection to reduce resource waste // and discourage addr stamping of INV announcements. - if (pfrom->fSentAddr) { - LogPrint(BCLog::NET, "Ignoring repeated \"getaddr\". peer=%d\n", pfrom->GetId()); - return true; + if (pfrom.fSentAddr) { + LogPrint(BCLog::NET, "Ignoring repeated \"getaddr\". peer=%d\n", pfrom.GetId()); + return; } - pfrom->fSentAddr = true; + pfrom.fSentAddr = true; - pfrom->vAddrToSend.clear(); - std::vector<CAddress> vAddr = connman->GetAddresses(); + pfrom.vAddrToSend.clear(); + std::vector<CAddress> vAddr; + if (pfrom.HasPermission(PF_ADDR)) { + vAddr = connman.GetAddresses(MAX_ADDR_TO_SEND, MAX_PCT_ADDR_TO_SEND); + } else { + vAddr = connman.GetAddresses(pfrom.addr.GetNetwork(), MAX_ADDR_TO_SEND, MAX_PCT_ADDR_TO_SEND); + } FastRandomContext insecure_rand; for (const CAddress &addr : vAddr) { - if (!banman->IsBanned(addr)) { - pfrom->PushAddress(addr, insecure_rand); - } + pfrom.PushAddress(addr, insecure_rand); } - return true; + return; } if (msg_type == NetMsgType::MEMPOOL) { - if (!(pfrom->GetLocalServices() & NODE_BLOOM) && !pfrom->HasPermission(PF_MEMPOOL)) + if (!(pfrom.GetLocalServices() & NODE_BLOOM) && !pfrom.HasPermission(PF_MEMPOOL)) { - if (!pfrom->HasPermission(PF_NOBAN)) + if (!pfrom.HasPermission(PF_NOBAN)) { - LogPrint(BCLog::NET, "mempool request with bloom filters disabled, disconnect peer=%d\n", pfrom->GetId()); - pfrom->fDisconnect = true; + LogPrint(BCLog::NET, "mempool request with bloom filters disabled, disconnect peer=%d\n", pfrom.GetId()); + pfrom.fDisconnect = true; } - return true; + return; } - if (connman->OutboundTargetReached(false) && !pfrom->HasPermission(PF_MEMPOOL)) + if (connman.OutboundTargetReached(false) && !pfrom.HasPermission(PF_MEMPOOL)) { - if (!pfrom->HasPermission(PF_NOBAN)) + if (!pfrom.HasPermission(PF_NOBAN)) { - LogPrint(BCLog::NET, "mempool request with bandwidth limit reached, disconnect peer=%d\n", pfrom->GetId()); - pfrom->fDisconnect = true; + LogPrint(BCLog::NET, "mempool request with bandwidth limit reached, disconnect peer=%d\n", pfrom.GetId()); + pfrom.fDisconnect = true; } - return true; + return; } - if (pfrom->m_tx_relay != nullptr) { - LOCK(pfrom->m_tx_relay->cs_tx_inventory); - pfrom->m_tx_relay->fSendMempool = true; + if (pfrom.m_tx_relay != nullptr) { + LOCK(pfrom.m_tx_relay->cs_tx_inventory); + pfrom.m_tx_relay->fSendMempool = true; } - return true; + return; } if (msg_type == NetMsgType::PING) { - if (pfrom->nVersion > BIP0031_VERSION) + if (pfrom.nVersion > BIP0031_VERSION) { uint64_t nonce = 0; vRecv >> nonce; @@ -3382,13 +3577,13 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec // it, if the remote node sends a ping once per second and this node takes 5 // seconds to respond to each, the 5th ping the remote sends would appear to // return very quickly. - connman->PushMessage(pfrom, msgMaker.Make(NetMsgType::PONG, nonce)); + connman.PushMessage(&pfrom, msgMaker.Make(NetMsgType::PONG, nonce)); } - return true; + return; } if (msg_type == NetMsgType::PONG) { - int64_t pingUsecEnd = nTimeReceived; + const auto ping_end = time_received; uint64_t nonce = 0; size_t nAvail = vRecv.in_avail(); bool bPingFinished = false; @@ -3398,15 +3593,15 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec vRecv >> nonce; // Only process pong message if there is an outstanding ping (old ping without nonce should never pong) - if (pfrom->nPingNonceSent != 0) { - if (nonce == pfrom->nPingNonceSent) { + if (pfrom.nPingNonceSent != 0) { + if (nonce == pfrom.nPingNonceSent) { // Matching pong received, this ping is no longer outstanding bPingFinished = true; - int64_t pingUsecTime = pingUsecEnd - pfrom->nPingUsecStart; - if (pingUsecTime > 0) { + const auto ping_time = ping_end - pfrom.m_ping_start.load(); + if (ping_time.count() > 0) { // Successful ping time measurement, replace previous - pfrom->nPingUsecTime = pingUsecTime; - pfrom->nMinPingUsecTime = std::min(pfrom->nMinPingUsecTime.load(), pingUsecTime); + pfrom.nPingUsecTime = count_microseconds(ping_time); + pfrom.nMinPingUsecTime = std::min(pfrom.nMinPingUsecTime.load(), count_microseconds(ping_time)); } else { // This should never happen sProblem = "Timing mishap"; @@ -3431,19 +3626,23 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec if (!(sProblem.empty())) { LogPrint(BCLog::NET, "pong peer=%d: %s, %x expected, %x received, %u bytes\n", - pfrom->GetId(), + pfrom.GetId(), sProblem, - pfrom->nPingNonceSent, + pfrom.nPingNonceSent, nonce, nAvail); } if (bPingFinished) { - pfrom->nPingNonceSent = 0; + pfrom.nPingNonceSent = 0; } - return true; + return; } if (msg_type == NetMsgType::FILTERLOAD) { + if (!(pfrom.GetLocalServices() & NODE_BLOOM)) { + pfrom.fDisconnect = true; + return; + } CBloomFilter filter; vRecv >> filter; @@ -3451,18 +3650,22 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec { // There is no excuse for sending a too-large filter LOCK(cs_main); - Misbehaving(pfrom->GetId(), 100); + Misbehaving(pfrom.GetId(), 100, "too-large bloom filter"); } - else if (pfrom->m_tx_relay != nullptr) + else if (pfrom.m_tx_relay != nullptr) { - LOCK(pfrom->m_tx_relay->cs_filter); - pfrom->m_tx_relay->pfilter.reset(new CBloomFilter(filter)); - pfrom->m_tx_relay->fRelayTxes = true; + LOCK(pfrom.m_tx_relay->cs_filter); + pfrom.m_tx_relay->pfilter.reset(new CBloomFilter(filter)); + pfrom.m_tx_relay->fRelayTxes = true; } - return true; + return; } if (msg_type == NetMsgType::FILTERADD) { + if (!(pfrom.GetLocalServices() & NODE_BLOOM)) { + pfrom.fDisconnect = true; + return; + } std::vector<unsigned char> vData; vRecv >> vData; @@ -3471,70 +3674,72 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec bool bad = false; if (vData.size() > MAX_SCRIPT_ELEMENT_SIZE) { bad = true; - } else if (pfrom->m_tx_relay != nullptr) { - LOCK(pfrom->m_tx_relay->cs_filter); - if (pfrom->m_tx_relay->pfilter) { - pfrom->m_tx_relay->pfilter->insert(vData); + } else if (pfrom.m_tx_relay != nullptr) { + LOCK(pfrom.m_tx_relay->cs_filter); + if (pfrom.m_tx_relay->pfilter) { + pfrom.m_tx_relay->pfilter->insert(vData); } else { bad = true; } } if (bad) { LOCK(cs_main); - Misbehaving(pfrom->GetId(), 100); + Misbehaving(pfrom.GetId(), 100, "bad filteradd message"); } - return true; + return; } if (msg_type == NetMsgType::FILTERCLEAR) { - if (pfrom->m_tx_relay == nullptr) { - return true; + if (!(pfrom.GetLocalServices() & NODE_BLOOM)) { + pfrom.fDisconnect = true; + return; } - LOCK(pfrom->m_tx_relay->cs_filter); - if (pfrom->GetLocalServices() & NODE_BLOOM) { - pfrom->m_tx_relay->pfilter = nullptr; + if (pfrom.m_tx_relay == nullptr) { + return; } - pfrom->m_tx_relay->fRelayTxes = true; - return true; + LOCK(pfrom.m_tx_relay->cs_filter); + pfrom.m_tx_relay->pfilter = nullptr; + pfrom.m_tx_relay->fRelayTxes = true; + return; } if (msg_type == NetMsgType::FEEFILTER) { CAmount newFeeFilter = 0; vRecv >> newFeeFilter; if (MoneyRange(newFeeFilter)) { - if (pfrom->m_tx_relay != nullptr) { - LOCK(pfrom->m_tx_relay->cs_feeFilter); - pfrom->m_tx_relay->minFeeFilter = newFeeFilter; + if (pfrom.m_tx_relay != nullptr) { + LOCK(pfrom.m_tx_relay->cs_feeFilter); + pfrom.m_tx_relay->minFeeFilter = newFeeFilter; } - LogPrint(BCLog::NET, "received: feefilter of %s from peer=%d\n", CFeeRate(newFeeFilter).ToString(), pfrom->GetId()); + LogPrint(BCLog::NET, "received: feefilter of %s from peer=%d\n", CFeeRate(newFeeFilter).ToString(), pfrom.GetId()); } - return true; + return; } if (msg_type == NetMsgType::GETCFILTERS) { - ProcessGetCFilters(*pfrom, vRecv, chainparams, *connman); - return true; + ProcessGetCFilters(pfrom, vRecv, chainparams, connman); + return; } if (msg_type == NetMsgType::GETCFHEADERS) { - ProcessGetCFHeaders(*pfrom, vRecv, chainparams, *connman); - return true; + ProcessGetCFHeaders(pfrom, vRecv, chainparams, connman); + return; } if (msg_type == NetMsgType::GETCFCHECKPT) { - ProcessGetCFCheckPt(*pfrom, vRecv, chainparams, *connman); - return true; + ProcessGetCFCheckPt(pfrom, vRecv, chainparams, connman); + return; } if (msg_type == NetMsgType::NOTFOUND) { // Remove the NOTFOUND transactions from the peer LOCK(cs_main); - CNodeState *state = State(pfrom->GetId()); + CNodeState *state = State(pfrom.GetId()); std::vector<CInv> vInv; vRecv >> vInv; if (vInv.size() <= MAX_PEER_TX_IN_FLIGHT + MAX_BLOCKS_IN_TRANSIT_PER_PEER) { for (CInv &inv : vInv) { - if (inv.type == MSG_TX || inv.type == MSG_WITNESS_TX) { + if (inv.IsGenTxMsg()) { // If we receive a NOTFOUND message for a txid we requested, erase // it from our data structures for this peer. auto in_flight_it = state->m_tx_download.m_tx_in_flight.find(inv.hash); @@ -3548,39 +3753,57 @@ bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRec } } } - return true; + return; } // Ignore unknown commands for extensibility - LogPrint(BCLog::NET, "Unknown command \"%s\" from peer=%d\n", SanitizeString(msg_type), pfrom->GetId()); - return true; + LogPrint(BCLog::NET, "Unknown command \"%s\" from peer=%d\n", SanitizeString(msg_type), pfrom.GetId()); + return; } -bool PeerLogicValidation::CheckIfBanned(CNode* pnode) +/** Maybe disconnect a peer and discourage future connections from its address. + * + * @param[in] pnode The node to check. + * @return True if the peer was marked for disconnection in this function + */ +bool PeerLogicValidation::MaybeDiscourageAndDisconnect(CNode& pnode) { - AssertLockHeld(cs_main); - CNodeState &state = *State(pnode->GetId()); - - if (state.fShouldBan) { - state.fShouldBan = false; - if (pnode->HasPermission(PF_NOBAN)) - LogPrintf("Warning: not punishing whitelisted peer %s!\n", pnode->addr.ToString()); - else if (pnode->m_manual_connection) - LogPrintf("Warning: not punishing manually-connected peer %s!\n", pnode->addr.ToString()); - else if (pnode->addr.IsLocal()) { - // Disconnect but don't ban _this_ local node - LogPrintf("Warning: disconnecting but not banning local peer %s!\n", pnode->addr.ToString()); - pnode->fDisconnect = true; - } else { - // Disconnect and ban all nodes sharing the address - if (m_banman) { - m_banman->Ban(pnode->addr, BanReasonNodeMisbehaving); - } - connman->DisconnectNode(pnode->addr); - } + const NodeId peer_id{pnode.GetId()}; + { + LOCK(cs_main); + CNodeState& state = *State(peer_id); + + // There's nothing to do if the m_should_discourage flag isn't set + if (!state.m_should_discourage) return false; + + state.m_should_discourage = false; + } // cs_main + + if (pnode.HasPermission(PF_NOBAN)) { + // We never disconnect or discourage peers for bad behavior if they have the NOBAN permission flag + LogPrintf("Warning: not punishing noban peer %d!\n", peer_id); + return false; + } + + if (pnode.IsManualConn()) { + // We never disconnect or discourage manual peers for bad behavior + LogPrintf("Warning: not punishing manually connected peer %d!\n", peer_id); + return false; + } + + if (pnode.addr.IsLocal()) { + // We disconnect local peers for bad behavior but don't discourage (since that would discourage + // all peers on the same local address) + LogPrintf("Warning: disconnecting but not discouraging local peer %d!\n", peer_id); + pnode.fDisconnect = true; return true; } - return false; + + // Normal case: Disconnect the peer and discourage all nodes sharing the address + LogPrintf("Disconnecting and discouraging peer %d!\n", peer_id); + if (m_banman) m_banman->Discourage(pnode.addr); + connman->DisconnectNode(pnode.addr); + return true; } bool PeerLogicValidation::ProcessMessages(CNode* pfrom, std::atomic<bool>& interruptMsgProc) @@ -3597,12 +3820,12 @@ bool PeerLogicValidation::ProcessMessages(CNode* pfrom, std::atomic<bool>& inter bool fMoreWork = false; if (!pfrom->vRecvGetData.empty()) - ProcessGetData(pfrom, chainparams, connman, m_mempool, interruptMsgProc); + ProcessGetData(*pfrom, chainparams, *connman, m_mempool, interruptMsgProc); if (!pfrom->orphan_work_set.empty()) { std::list<CTransactionRef> removed_txn; LOCK2(cs_main, g_cs_orphans); - ProcessOrphanTx(connman, m_mempool, pfrom->orphan_work_set, removed_txn); + ProcessOrphanTx(*connman, m_mempool, pfrom->orphan_work_set, removed_txn); for (const CTransactionRef& removedTx : removed_txn) { AddToCompactExtraTransactions(removedTx); } @@ -3661,11 +3884,8 @@ bool PeerLogicValidation::ProcessMessages(CNode* pfrom, std::atomic<bool>& inter return fMoreWork; } - // Process message - bool fRet = false; - try - { - fRet = ProcessMessage(pfrom, msg_type, vRecv, msg.m_time, chainparams, m_chainman, m_mempool, connman, m_banman, interruptMsgProc); + try { + ProcessMessage(*pfrom, msg_type, vRecv, msg.m_time, chainparams, m_chainman, m_mempool, *connman, m_banman, interruptMsgProc); if (interruptMsgProc) return false; if (!pfrom->vRecvGetData.empty()) @@ -3676,24 +3896,17 @@ bool PeerLogicValidation::ProcessMessages(CNode* pfrom, std::atomic<bool>& inter LogPrint(BCLog::NET, "%s(%s, %u bytes): Unknown exception caught\n", __func__, SanitizeString(msg_type), nMessageSize); } - if (!fRet) { - LogPrint(BCLog::NET, "%s(%s, %u bytes) FAILED peer=%d\n", __func__, SanitizeString(msg_type), nMessageSize, pfrom->GetId()); - } - - LOCK(cs_main); - CheckIfBanned(pfrom); - return fMoreWork; } -void PeerLogicValidation::ConsiderEviction(CNode *pto, int64_t time_in_seconds) +void PeerLogicValidation::ConsiderEviction(CNode& pto, int64_t time_in_seconds) { AssertLockHeld(cs_main); - CNodeState &state = *State(pto->GetId()); - const CNetMsgMaker msgMaker(pto->GetSendVersion()); + CNodeState &state = *State(pto.GetId()); + const CNetMsgMaker msgMaker(pto.GetSendVersion()); - if (!state.m_chain_sync.m_protect && IsOutboundDisconnectionCandidate(pto) && state.fSyncStarted) { + if (!state.m_chain_sync.m_protect && pto.IsOutboundOrBlockRelayConn() && state.fSyncStarted) { // This is an outbound peer subject to disconnection if they don't // announce a block with as much work as the current tip within // CHAIN_SYNC_TIMEOUT + HEADERS_RESPONSE_TIME seconds (note: if @@ -3720,12 +3933,12 @@ void PeerLogicValidation::ConsiderEviction(CNode *pto, int64_t time_in_seconds) // message to give the peer a chance to update us. if (state.m_chain_sync.m_sent_getheaders) { // They've run out of time to catch up! - LogPrintf("Disconnecting outbound peer %d for old chain, best known block = %s\n", pto->GetId(), state.pindexBestKnownBlock != nullptr ? state.pindexBestKnownBlock->GetBlockHash().ToString() : "<none>"); - pto->fDisconnect = true; + LogPrintf("Disconnecting outbound peer %d for old chain, best known block = %s\n", pto.GetId(), state.pindexBestKnownBlock != nullptr ? state.pindexBestKnownBlock->GetBlockHash().ToString() : "<none>"); + pto.fDisconnect = true; } else { assert(state.m_chain_sync.m_work_header); - LogPrint(BCLog::NET, "sending getheaders to outbound peer=%d to verify chain work (current best known block:%s, benchmark blockhash: %s)\n", pto->GetId(), state.pindexBestKnownBlock != nullptr ? state.pindexBestKnownBlock->GetBlockHash().ToString() : "<none>", state.m_chain_sync.m_work_header->GetBlockHash().ToString()); - connman->PushMessage(pto, msgMaker.Make(NetMsgType::GETHEADERS, ::ChainActive().GetLocator(state.m_chain_sync.m_work_header->pprev), uint256())); + LogPrint(BCLog::NET, "sending getheaders to outbound peer=%d to verify chain work (current best known block:%s, benchmark blockhash: %s)\n", pto.GetId(), state.pindexBestKnownBlock != nullptr ? state.pindexBestKnownBlock->GetBlockHash().ToString() : "<none>", state.m_chain_sync.m_work_header->GetBlockHash().ToString()); + connman->PushMessage(&pto, msgMaker.Make(NetMsgType::GETHEADERS, ::ChainActive().GetLocator(state.m_chain_sync.m_work_header->pprev), uint256())); state.m_chain_sync.m_sent_getheaders = true; constexpr int64_t HEADERS_RESPONSE_TIME = 120; // 2 minutes // Bump the timeout to allow a response, which could clear the timeout @@ -3755,7 +3968,7 @@ void PeerLogicValidation::EvictExtraOutboundPeers(int64_t time_in_seconds) AssertLockHeld(cs_main); // Ignore non-outbound peers, or nodes marked for disconnect already - if (!IsOutboundDisconnectionCandidate(pnode) || pnode->fDisconnect) return; + if (!pnode->IsOutboundOrBlockRelayConn() || pnode->fDisconnect) return; CNodeState *state = State(pnode->GetId()); if (state == nullptr) return; // shouldn't be possible, but just in case // Don't evict our protected peers @@ -3825,17 +4038,19 @@ namespace { class CompareInvMempoolOrder { CTxMemPool *mp; + bool m_wtxid_relay; public: - explicit CompareInvMempoolOrder(CTxMemPool *_mempool) + explicit CompareInvMempoolOrder(CTxMemPool *_mempool, bool use_wtxid) { mp = _mempool; + m_wtxid_relay = use_wtxid; } bool operator()(std::set<uint256>::iterator a, std::set<uint256>::iterator b) { /* As std::make_heap produces a max-heap, we want the entries with the * fewest ancestors/highest fee to sort later. */ - return mp->CompareDepthAndScore(*b, *a); + return mp->CompareDepthAndScore(*b, *a, m_wtxid_relay); } }; } @@ -3843,48 +4058,49 @@ public: bool PeerLogicValidation::SendMessages(CNode* pto) { const Consensus::Params& consensusParams = Params().GetConsensus(); - { - // Don't send anything until the version handshake is complete - if (!pto->fSuccessfullyConnected || pto->fDisconnect) - return true; - // If we get here, the outgoing message serialization version is set and can't change. - const CNetMsgMaker msgMaker(pto->GetSendVersion()); + // We must call MaybeDiscourageAndDisconnect first, to ensure that we'll + // disconnect misbehaving peers even before the version handshake is complete. + if (MaybeDiscourageAndDisconnect(*pto)) return true; - // - // Message: ping - // - bool pingSend = false; - if (pto->fPingQueued) { - // RPC ping request by user - pingSend = true; - } - if (pto->nPingNonceSent == 0 && pto->nPingUsecStart + PING_INTERVAL * 1000000 < GetTimeMicros()) { - // Ping automatically sent as a latency probe & keepalive. - pingSend = true; - } - if (pingSend) { - uint64_t nonce = 0; - while (nonce == 0) { - GetRandBytes((unsigned char*)&nonce, sizeof(nonce)); - } - pto->fPingQueued = false; - pto->nPingUsecStart = GetTimeMicros(); - if (pto->nVersion > BIP0031_VERSION) { - pto->nPingNonceSent = nonce; - connman->PushMessage(pto, msgMaker.Make(NetMsgType::PING, nonce)); - } else { - // Peer is too old to support ping command with nonce, pong will never arrive. - pto->nPingNonceSent = 0; - connman->PushMessage(pto, msgMaker.Make(NetMsgType::PING)); - } - } + // Don't send anything until the version handshake is complete + if (!pto->fSuccessfullyConnected || pto->fDisconnect) + return true; - TRY_LOCK(cs_main, lockMain); - if (!lockMain) - return true; + // If we get here, the outgoing message serialization version is set and can't change. + const CNetMsgMaker msgMaker(pto->GetSendVersion()); - if (CheckIfBanned(pto)) return true; + // + // Message: ping + // + bool pingSend = false; + if (pto->fPingQueued) { + // RPC ping request by user + pingSend = true; + } + if (pto->nPingNonceSent == 0 && pto->m_ping_start.load() + PING_INTERVAL < GetTime<std::chrono::microseconds>()) { + // Ping automatically sent as a latency probe & keepalive. + pingSend = true; + } + if (pingSend) { + uint64_t nonce = 0; + while (nonce == 0) { + GetRandBytes((unsigned char*)&nonce, sizeof(nonce)); + } + pto->fPingQueued = false; + pto->m_ping_start = GetTime<std::chrono::microseconds>(); + if (pto->nVersion > BIP0031_VERSION) { + pto->nPingNonceSent = nonce; + connman->PushMessage(pto, msgMaker.Make(NetMsgType::PING, nonce)); + } else { + // Peer is too old to support ping command with nonce, pong will never arrive. + pto->nPingNonceSent = 0; + connman->PushMessage(pto, msgMaker.Make(NetMsgType::PING)); + } + } + + { + LOCK(cs_main); CNodeState &state = *State(pto->GetId()); @@ -3911,8 +4127,8 @@ bool PeerLogicValidation::SendMessages(CNode* pto) { pto->m_addr_known->insert(addr.GetKey()); vAddr.push_back(addr); - // receiver rejects addr messages larger than 1000 - if (vAddr.size() >= 1000) + // receiver rejects addr messages larger than MAX_ADDR_TO_SEND + if (vAddr.size() >= MAX_ADDR_TO_SEND) { connman->PushMessage(pto, msgMaker.Make(NetMsgType::ADDR, vAddr)); vAddr.clear(); @@ -3930,7 +4146,7 @@ bool PeerLogicValidation::SendMessages(CNode* pto) // Start block sync if (pindexBestHeader == nullptr) pindexBestHeader = ::ChainActive().Tip(); - bool fFetch = state.fPreferredDownload || (nPreferredDownload == 0 && !pto->fClient && !pto->fOneShot); // Download if this is a nice peer, or we have no nice peers and this one might do. + bool fFetch = state.fPreferredDownload || (nPreferredDownload == 0 && !pto->fClient && !pto->IsAddrFetchConn()); // Download if this is a nice peer, or we have no nice peers and this one might do. if (!state.fSyncStarted && !pto->fClient && !fImporting && !fReindex) { // Only actively request headers from a single peer, unless we're close to today. if ((nSyncStarted == 0 && fFetch) || pindexBestHeader->GetBlockTime() > GetAdjustedTime() - 24 * 60 * 60) { @@ -4082,7 +4298,7 @@ bool PeerLogicValidation::SendMessages(CNode* pto) // If the peer's chain has this block, don't inv it back. if (!PeerHasHeader(&state, pindex)) { - pto->PushInventory(CInv(MSG_BLOCK, hashToAnnounce)); + pto->vInventoryBlockToSend.push_back(hashToAnnounce); LogPrint(BCLog::NET, "%s: sending inv peer=%d hash=%s\n", __func__, pto->GetId(), hashToAnnounce.ToString()); } @@ -4115,7 +4331,7 @@ bool PeerLogicValidation::SendMessages(CNode* pto) bool fSendTrickle = pto->HasPermission(PF_NOBAN); if (pto->m_tx_relay->nNextInvSend < current_time) { fSendTrickle = true; - if (pto->fInbound) { + if (pto->IsInboundConn()) { pto->m_tx_relay->nNextInvSend = std::chrono::microseconds{connman->PoissonNextSendInbound(nNow, INVENTORY_BROADCAST_INTERVAL)}; } else { // Use half the delay for outbound peers, as there is less privacy concern for them. @@ -4142,8 +4358,8 @@ bool PeerLogicValidation::SendMessages(CNode* pto) LOCK(pto->m_tx_relay->cs_filter); for (const auto& txinfo : vtxinfo) { - const uint256& hash = txinfo.tx->GetHash(); - CInv inv(MSG_TX, hash); + const uint256& hash = state.m_wtxid_relay ? txinfo.tx->GetWitnessHash() : txinfo.tx->GetHash(); + CInv inv(state.m_wtxid_relay ? MSG_WTX : MSG_TX, hash); pto->m_tx_relay->setInventoryTxToSend.erase(hash); // Don't send transactions that peers will not put into their mempool if (txinfo.fee < filterrate.GetFee(txinfo.vsize)) { @@ -4153,6 +4369,7 @@ bool PeerLogicValidation::SendMessages(CNode* pto) if (!pto->m_tx_relay->pfilter->IsRelevantAndUpdate(*txinfo.tx)) continue; } pto->m_tx_relay->filterInventoryKnown.insert(hash); + // Responses to MEMPOOL requests bypass the m_recently_announced_invs filter. vInv.push_back(inv); if (vInv.size() == MAX_INV_SZ) { connman->PushMessage(pto, msgMaker.Make(NetMsgType::INV, vInv)); @@ -4177,7 +4394,7 @@ bool PeerLogicValidation::SendMessages(CNode* pto) } // Topologically and fee-rate sort the inventory we send for privacy and priority reasons. // A heap is used so that not all items need sorting if only a few are being sent. - CompareInvMempoolOrder compareInvMempoolOrder(&m_mempool); + CompareInvMempoolOrder compareInvMempoolOrder(&m_mempool, state.m_wtxid_relay); std::make_heap(vInvTx.begin(), vInvTx.end(), compareInvMempoolOrder); // No reason to drain out at many times the network's capacity, // especially since we have many peers and some will draw much shorter delays. @@ -4189,6 +4406,7 @@ bool PeerLogicValidation::SendMessages(CNode* pto) std::set<uint256>::iterator it = vInvTx.back(); vInvTx.pop_back(); uint256 hash = *it; + CInv inv(state.m_wtxid_relay ? MSG_WTX : MSG_TX, hash); // Remove it from the to-be-sent set pto->m_tx_relay->setInventoryTxToSend.erase(it); // Check if not in the filter already @@ -4196,17 +4414,20 @@ bool PeerLogicValidation::SendMessages(CNode* pto) continue; } // Not in the mempool anymore? don't bother sending it. - auto txinfo = m_mempool.info(hash); + auto txinfo = m_mempool.info(ToGenTxid(inv)); if (!txinfo.tx) { continue; } + auto txid = txinfo.tx->GetHash(); + auto wtxid = txinfo.tx->GetWitnessHash(); // Peer told you to not send transactions at that feerate? Don't bother sending it. if (txinfo.fee < filterrate.GetFee(txinfo.vsize)) { continue; } if (pto->m_tx_relay->pfilter && !pto->m_tx_relay->pfilter->IsRelevantAndUpdate(*txinfo.tx)) continue; // Send - vInv.push_back(CInv(MSG_TX, hash)); + State(pto->GetId())->m_recently_announced_invs.insert(hash); + vInv.push_back(inv); nRelayedTransactions++; { // Expire old relay messages @@ -4216,9 +4437,14 @@ bool PeerLogicValidation::SendMessages(CNode* pto) vRelayExpiration.pop_front(); } - auto ret = mapRelay.insert(std::make_pair(hash, std::move(txinfo.tx))); + auto ret = mapRelay.emplace(txid, std::move(txinfo.tx)); if (ret.second) { - vRelayExpiration.push_back(std::make_pair(nNow + std::chrono::microseconds{RELAY_TX_CACHE_TIME}.count(), ret.first)); + vRelayExpiration.emplace_back(nNow + std::chrono::microseconds{RELAY_TX_CACHE_TIME}.count(), ret.first); + } + // Add wtxid-based lookup into mapRelay as well, so that peers can request by wtxid + auto ret2 = mapRelay.emplace(wtxid, ret.first->second); + if (ret2.second) { + vRelayExpiration.emplace_back(nNow + std::chrono::microseconds{RELAY_TX_CACHE_TIME}.count(), ret2.first); } } if (vInv.size() == MAX_INV_SZ) { @@ -4226,6 +4452,14 @@ bool PeerLogicValidation::SendMessages(CNode* pto) vInv.clear(); } pto->m_tx_relay->filterInventoryKnown.insert(hash); + if (hash != txid) { + // Insert txid into filterInventoryKnown, even for + // wtxidrelay peers. This prevents re-adding of + // unconfirmed parents to the recently_announced + // filter, when a child tx is requested. See + // ProcessGetData(). + pto->m_tx_relay->filterInventoryKnown.insert(txid); + } } } } @@ -4263,9 +4497,9 @@ bool PeerLogicValidation::SendMessages(CNode* pto) // Check for headers sync timeouts if (state.fSyncStarted && state.nHeadersSyncTimeout < std::numeric_limits<int64_t>::max()) { // Detect whether this is a stalling initial-headers-sync peer - if (pindexBestHeader->GetBlockTime() <= GetAdjustedTime() - 24*60*60) { + if (pindexBestHeader->GetBlockTime() <= GetAdjustedTime() - 24 * 60 * 60) { if (nNow > state.nHeadersSyncTimeout && nSyncStarted == 1 && (nPreferredDownload - state.fPreferredDownload >= 1)) { - // Disconnect a (non-whitelisted) peer if it is our only sync peer, + // Disconnect a peer (without the noban permission) if it is our only sync peer, // and we have others we could be using instead. // Note: If all our peers are inbound, then we won't // disconnect our sync peer for stalling; we have bigger @@ -4275,7 +4509,7 @@ bool PeerLogicValidation::SendMessages(CNode* pto) pto->fDisconnect = true; return true; } else { - LogPrintf("Timeout downloading headers from whitelisted peer=%d, not disconnecting\n", pto->GetId()); + LogPrintf("Timeout downloading headers from noban peer=%d, not disconnecting\n", pto->GetId()); // Reset the headers sync state so that we have a // chance to try downloading from a different peer. // Note: this will also result in at least one more @@ -4295,7 +4529,7 @@ bool PeerLogicValidation::SendMessages(CNode* pto) // Check that outbound peers have reasonable chains // GetTime() is used by this anti-DoS logic so we can test this using mocktime - ConsiderEviction(pto, GetTime()); + ConsiderEviction(*pto, GetTime()); // // Message: getdata (blocks) @@ -4306,7 +4540,7 @@ bool PeerLogicValidation::SendMessages(CNode* pto) NodeId staller = -1; FindNextBlocksToDownload(pto->GetId(), MAX_BLOCKS_IN_TRANSIT_PER_PEER - state.nBlocksInFlight, vToDownload, staller, consensusParams); for (const CBlockIndex *pindex : vToDownload) { - uint32_t nFetchFlags = GetFetchFlags(pto); + uint32_t nFetchFlags = GetFetchFlags(*pto); vGetData.push_back(CInv(MSG_BLOCK | nFetchFlags, pindex->GetBlockHash())); MarkBlockAsInFlight(m_mempool, pto->GetId(), pindex->GetBlockHash(), pindex); LogPrint(BCLog::NET, "Requesting block %s (%d) peer=%d\n", pindex->GetBlockHash().ToString(), @@ -4346,15 +4580,15 @@ bool PeerLogicValidation::SendMessages(CNode* pto) auto& tx_process_time = state.m_tx_download.m_tx_process_time; while (!tx_process_time.empty() && tx_process_time.begin()->first <= current_time && state.m_tx_download.m_tx_in_flight.size() < MAX_PEER_TX_IN_FLIGHT) { - const uint256 txid = tx_process_time.begin()->second; + const GenTxid gtxid = tx_process_time.begin()->second; // Erase this entry from tx_process_time (it may be added back for // processing at a later time, see below) tx_process_time.erase(tx_process_time.begin()); - CInv inv(MSG_TX | GetFetchFlags(pto), txid); + CInv inv(gtxid.IsWtxid() ? MSG_WTX : (MSG_TX | GetFetchFlags(*pto)), gtxid.GetHash()); if (!AlreadyHave(inv, m_mempool)) { // If this transaction was last requested more than 1 minute ago, // then request. - const auto last_request_time = GetTxRequestTime(inv.hash); + const auto last_request_time = GetTxRequestTime(gtxid); if (last_request_time <= current_time - GETDATA_TX_INTERVAL) { LogPrint(BCLog::NET, "Requesting %s peer=%d\n", inv.ToString(), pto->GetId()); vGetData.push_back(inv); @@ -4362,20 +4596,28 @@ bool PeerLogicValidation::SendMessages(CNode* pto) connman->PushMessage(pto, msgMaker.Make(NetMsgType::GETDATA, vGetData)); vGetData.clear(); } - UpdateTxRequestTime(inv.hash, current_time); - state.m_tx_download.m_tx_in_flight.emplace(inv.hash, current_time); + UpdateTxRequestTime(gtxid, current_time); + state.m_tx_download.m_tx_in_flight.emplace(gtxid.GetHash(), current_time); } else { // This transaction is in flight from someone else; queue // up processing to happen after the download times out // (with a slight delay for inbound peers, to prefer // requests to outbound peers). - const auto next_process_time = CalculateTxGetDataTime(txid, current_time, !state.fPreferredDownload); - tx_process_time.emplace(next_process_time, txid); + // Don't apply the txid-delay to re-requests of a + // transaction; the heuristic of delaying requests to + // txid-relay peers is to save bandwidth on initial + // announcement of a transaction, and doesn't make sense + // for a followup request if our first peer times out (and + // would open us up to an attacker using inbound + // wtxid-relay to prevent us from requesting transactions + // from outbound txid-relay peers). + const auto next_process_time = CalculateTxGetDataTime(gtxid, current_time, !state.fPreferredDownload, false); + tx_process_time.emplace(next_process_time, gtxid); } } else { // We have already seen this transaction, no need to download. - state.m_tx_download.m_tx_announced.erase(inv.hash); - state.m_tx_download.m_tx_in_flight.erase(inv.hash); + state.m_tx_download.m_tx_announced.erase(gtxid.GetHash()); + state.m_tx_download.m_tx_in_flight.erase(gtxid.GetHash()); } } @@ -4386,15 +4628,26 @@ bool PeerLogicValidation::SendMessages(CNode* pto) // // Message: feefilter // - // We don't want white listed peers to filter txs to us if we have -whitelistforcerelay if (pto->m_tx_relay != nullptr && pto->nVersion >= FEEFILTER_VERSION && gArgs.GetBoolArg("-feefilter", DEFAULT_FEEFILTER) && - !pto->HasPermission(PF_FORCERELAY)) { + !pto->HasPermission(PF_FORCERELAY) // peers with the forcerelay permission should not filter txs to us + ) { CAmount currentFilter = m_mempool.GetMinFee(gArgs.GetArg("-maxmempool", DEFAULT_MAX_MEMPOOL_SIZE) * 1000000).GetFeePerK(); int64_t timeNow = GetTimeMicros(); + static FeeFilterRounder g_filter_rounder{CFeeRate{DEFAULT_MIN_RELAY_TX_FEE}}; + if (m_chainman.ActiveChainstate().IsInitialBlockDownload()) { + // Received tx-inv messages are discarded when the active + // chainstate is in IBD, so tell the peer to not send them. + currentFilter = MAX_MONEY; + } else { + static const CAmount MAX_FILTER{g_filter_rounder.round(MAX_MONEY)}; + if (pto->m_tx_relay->lastSentFeeFilter == MAX_FILTER) { + // Send the current filter if we sent MAX_FILTER previously + // and made it out of IBD. + pto->m_tx_relay->nextSendTimeFeeFilter = timeNow - 1; + } + } if (timeNow > pto->m_tx_relay->nextSendTimeFeeFilter) { - static CFeeRate default_feerate(DEFAULT_MIN_RELAY_TX_FEE); - static FeeFilterRounder filterRounder(default_feerate); - CAmount filterToSend = filterRounder.round(currentFilter); + CAmount filterToSend = g_filter_rounder.round(currentFilter); // We always have a fee filter of at least minRelayTxFee filterToSend = std::max(filterToSend, ::minRelayTxFee.GetFeePerK()); if (filterToSend != pto->m_tx_relay->lastSentFeeFilter) { @@ -4410,7 +4663,7 @@ bool PeerLogicValidation::SendMessages(CNode* pto) pto->m_tx_relay->nextSendTimeFeeFilter = timeNow + GetRandInt(MAX_FEEFILTER_CHANGE_DELAY) * 1000000; } } - } + } // release cs_main return true; } @@ -4422,6 +4675,7 @@ public: // orphan transactions mapOrphanTransactions.clear(); mapOrphanTransactionsByPrev.clear(); + g_orphans_by_wtxid.clear(); } }; static CNetProcessingCleanup instance_of_cnetprocessingcleanup; diff --git a/src/net_processing.h b/src/net_processing.h index ec758c7537..2d98714122 100644 --- a/src/net_processing.h +++ b/src/net_processing.h @@ -23,15 +23,18 @@ static const unsigned int DEFAULT_MAX_ORPHAN_TRANSACTIONS = 100; static const unsigned int DEFAULT_BLOCK_RECONSTRUCTION_EXTRA_TXN = 100; static const bool DEFAULT_PEERBLOOMFILTERS = false; static const bool DEFAULT_PEERBLOCKFILTERS = false; +/** Threshold for marking a node to be discouraged, e.g. disconnected and added to the discouragement filter. */ +static const int DISCOURAGEMENT_THRESHOLD{100}; class PeerLogicValidation final : public CValidationInterface, public NetEventsInterface { private: CConnman* const connman; + /** Pointer to this node's banman. May be nullptr - check existence before dereferencing. */ BanMan* const m_banman; ChainstateManager& m_chainman; CTxMemPool& m_mempool; - bool CheckIfBanned(CNode* pnode) EXCLUSIVE_LOCKS_REQUIRED(cs_main); + bool MaybeDiscourageAndDisconnect(CNode& pnode); public: PeerLogicValidation(CConnman* connman, BanMan* banman, CScheduler& scheduler, ChainstateManager& chainman, CTxMemPool& pool); @@ -74,7 +77,7 @@ public: bool SendMessages(CNode* pto) override EXCLUSIVE_LOCKS_REQUIRED(pto->cs_sendProcessing); /** Consider evicting an outbound peer based on the amount of time they've been behind our tip */ - void ConsiderEviction(CNode *pto, int64_t time_in_seconds) EXCLUSIVE_LOCKS_REQUIRED(cs_main); + void ConsiderEviction(CNode& pto, int64_t time_in_seconds) EXCLUSIVE_LOCKS_REQUIRED(cs_main); /** Evict extra outbound peers. If we think our tip may be stale, connect to an extra outbound */ void CheckForStaleTipAndEvictPeers(const Consensus::Params &consensusParams); /** If we have extra outbound peers, try to disconnect the one with the oldest block announcement */ @@ -97,6 +100,6 @@ struct CNodeStateStats { bool GetNodeStateStats(NodeId nodeid, CNodeStateStats &stats); /** Relay transaction to every node */ -void RelayTransaction(const uint256&, const CConnman& connman); +void RelayTransaction(const uint256& txid, const uint256& wtxid, const CConnman& connman) EXCLUSIVE_LOCKS_REQUIRED(cs_main); #endif // BITCOIN_NET_PROCESSING_H diff --git a/src/netaddress.cpp b/src/netaddress.cpp index f79425a52e..d29aed6c8b 100644 --- a/src/netaddress.cpp +++ b/src/netaddress.cpp @@ -3,6 +3,7 @@ // Distributed under the MIT software license, see the accompanying // file COPYING or http://www.opensource.org/licenses/mit-license.php. +#include <cstdint> #include <netaddress.h> #include <hash.h> #include <util/strencodings.h> @@ -27,19 +28,35 @@ CNetAddr::CNetAddr() void CNetAddr::SetIP(const CNetAddr& ipIn) { + m_net = ipIn.m_net; memcpy(ip, ipIn.ip, sizeof(ip)); } +void CNetAddr::SetLegacyIPv6(const uint8_t ipv6[16]) +{ + if (memcmp(ipv6, pchIPv4, sizeof(pchIPv4)) == 0) { + m_net = NET_IPV4; + } else if (memcmp(ipv6, pchOnionCat, sizeof(pchOnionCat)) == 0) { + m_net = NET_ONION; + } else if (memcmp(ipv6, g_internal_prefix, sizeof(g_internal_prefix)) == 0) { + m_net = NET_INTERNAL; + } else { + m_net = NET_IPV6; + } + memcpy(ip, ipv6, 16); +} + void CNetAddr::SetRaw(Network network, const uint8_t *ip_in) { switch(network) { case NET_IPV4: + m_net = NET_IPV4; memcpy(ip, pchIPv4, 12); memcpy(ip+12, ip_in, 4); break; case NET_IPV6: - memcpy(ip, ip_in, 16); + SetLegacyIPv6(ip_in); break; default: assert(!"invalid network"); @@ -65,6 +82,7 @@ bool CNetAddr::SetInternal(const std::string &name) if (name.empty()) { return false; } + m_net = NET_INTERNAL; unsigned char hash[32] = {}; CSHA256().Write((const unsigned char*)name.data(), name.size()).Finalize(hash); memcpy(ip, g_internal_prefix, sizeof(g_internal_prefix)); @@ -88,6 +106,7 @@ bool CNetAddr::SetSpecial(const std::string &strName) std::vector<unsigned char> vchAddr = DecodeBase32(strName.substr(0, strName.size() - 6).c_str()); if (vchAddr.size() != 16-sizeof(pchOnionCat)) return false; + m_net = NET_ONION; memcpy(ip, pchOnionCat, sizeof(pchOnionCat)); for (unsigned int i=0; i<16-sizeof(pchOnionCat); i++) ip[i + sizeof(pchOnionCat)] = vchAddr[i]; @@ -122,15 +141,9 @@ bool CNetAddr::IsBindAny() const return true; } -bool CNetAddr::IsIPv4() const -{ - return (memcmp(ip, pchIPv4, sizeof(pchIPv4)) == 0); -} +bool CNetAddr::IsIPv4() const { return m_net == NET_IPV4; } -bool CNetAddr::IsIPv6() const -{ - return (!IsIPv4() && !IsTor() && !IsInternal()); -} +bool CNetAddr::IsIPv6() const { return m_net == NET_IPV6; } bool CNetAddr::IsRFC1918() const { @@ -164,50 +177,54 @@ bool CNetAddr::IsRFC5737() const bool CNetAddr::IsRFC3849() const { - return GetByte(15) == 0x20 && GetByte(14) == 0x01 && GetByte(13) == 0x0D && GetByte(12) == 0xB8; + return IsIPv6() && GetByte(15) == 0x20 && GetByte(14) == 0x01 && + GetByte(13) == 0x0D && GetByte(12) == 0xB8; } bool CNetAddr::IsRFC3964() const { - return (GetByte(15) == 0x20 && GetByte(14) == 0x02); + return IsIPv6() && GetByte(15) == 0x20 && GetByte(14) == 0x02; } bool CNetAddr::IsRFC6052() const { static const unsigned char pchRFC6052[] = {0,0x64,0xFF,0x9B,0,0,0,0,0,0,0,0}; - return (memcmp(ip, pchRFC6052, sizeof(pchRFC6052)) == 0); + return IsIPv6() && memcmp(ip, pchRFC6052, sizeof(pchRFC6052)) == 0; } bool CNetAddr::IsRFC4380() const { - return (GetByte(15) == 0x20 && GetByte(14) == 0x01 && GetByte(13) == 0 && GetByte(12) == 0); + return IsIPv6() && GetByte(15) == 0x20 && GetByte(14) == 0x01 && GetByte(13) == 0 && + GetByte(12) == 0; } bool CNetAddr::IsRFC4862() const { static const unsigned char pchRFC4862[] = {0xFE,0x80,0,0,0,0,0,0}; - return (memcmp(ip, pchRFC4862, sizeof(pchRFC4862)) == 0); + return IsIPv6() && memcmp(ip, pchRFC4862, sizeof(pchRFC4862)) == 0; } bool CNetAddr::IsRFC4193() const { - return ((GetByte(15) & 0xFE) == 0xFC); + return IsIPv6() && (GetByte(15) & 0xFE) == 0xFC; } bool CNetAddr::IsRFC6145() const { static const unsigned char pchRFC6145[] = {0,0,0,0,0,0,0,0,0xFF,0xFF,0,0}; - return (memcmp(ip, pchRFC6145, sizeof(pchRFC6145)) == 0); + return IsIPv6() && memcmp(ip, pchRFC6145, sizeof(pchRFC6145)) == 0; } bool CNetAddr::IsRFC4843() const { - return (GetByte(15) == 0x20 && GetByte(14) == 0x01 && GetByte(13) == 0x00 && (GetByte(12) & 0xF0) == 0x10); + return IsIPv6() && GetByte(15) == 0x20 && GetByte(14) == 0x01 && + GetByte(13) == 0x00 && (GetByte(12) & 0xF0) == 0x10; } bool CNetAddr::IsRFC7343() const { - return (GetByte(15) == 0x20 && GetByte(14) == 0x01 && GetByte(13) == 0x00 && (GetByte(12) & 0xF0) == 0x20); + return IsIPv6() && GetByte(15) == 0x20 && GetByte(14) == 0x01 && + GetByte(13) == 0x00 && (GetByte(12) & 0xF0) == 0x20; } bool CNetAddr::IsHeNet() const @@ -221,10 +238,7 @@ bool CNetAddr::IsHeNet() const * * @see CNetAddr::SetSpecial(const std::string &) */ -bool CNetAddr::IsTor() const -{ - return (memcmp(ip, pchOnionCat, sizeof(pchOnionCat)) == 0); -} +bool CNetAddr::IsTor() const { return m_net == NET_ONION; } bool CNetAddr::IsLocal() const { @@ -234,7 +248,7 @@ bool CNetAddr::IsLocal() const // IPv6 loopback (::1/128) static const unsigned char pchLocal[16] = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1}; - if (memcmp(ip, pchLocal, 16) == 0) + if (IsIPv6() && memcmp(ip, pchLocal, 16) == 0) return true; return false; @@ -258,12 +272,12 @@ bool CNetAddr::IsValid() const // header20 vectorlen3 addr26 addr26 addr26 header20 vectorlen3 addr26 addr26 addr26... // so if the first length field is garbled, it reads the second batch // of addr misaligned by 3 bytes. - if (memcmp(ip, pchIPv4+3, sizeof(pchIPv4)-3) == 0) + if (IsIPv6() && memcmp(ip, pchIPv4+3, sizeof(pchIPv4)-3) == 0) return false; // unspecified IPv6 address (::/128) unsigned char ipNone6[16] = {}; - if (memcmp(ip, ipNone6, 16) == 0) + if (IsIPv6() && memcmp(ip, ipNone6, 16) == 0) return false; // documentation IPv6 address @@ -310,7 +324,7 @@ bool CNetAddr::IsRoutable() const */ bool CNetAddr::IsInternal() const { - return memcmp(ip, g_internal_prefix, sizeof(g_internal_prefix)) == 0; + return m_net == NET_INTERNAL; } enum Network CNetAddr::GetNetwork() const @@ -321,13 +335,7 @@ enum Network CNetAddr::GetNetwork() const if (!IsRoutable()) return NET_UNROUTABLE; - if (IsIPv4()) - return NET_IPV4; - - if (IsTor()) - return NET_ONION; - - return NET_IPV6; + return m_net; } std::string CNetAddr::ToStringIP() const @@ -361,12 +369,12 @@ std::string CNetAddr::ToString() const bool operator==(const CNetAddr& a, const CNetAddr& b) { - return (memcmp(a.ip, b.ip, 16) == 0); + return a.m_net == b.m_net && memcmp(a.ip, b.ip, 16) == 0; } bool operator<(const CNetAddr& a, const CNetAddr& b) { - return (memcmp(a.ip, b.ip, 16) < 0); + return a.m_net < b.m_net || (a.m_net == b.m_net && memcmp(a.ip, b.ip, 16) < 0); } /** @@ -545,7 +553,7 @@ std::vector<unsigned char> CNetAddr::GetGroup(const std::vector<bool> &asmap) co uint64_t CNetAddr::GetHash() const { - uint256 hash = Hash(&ip[0], &ip[16]); + uint256 hash = Hash(ip); uint64_t nRet; memcpy(&nRet, &hash, sizeof(nRet)); return nRet; @@ -627,15 +635,15 @@ CService::CService() : port(0) { } -CService::CService(const CNetAddr& cip, unsigned short portIn) : CNetAddr(cip), port(portIn) +CService::CService(const CNetAddr& cip, uint16_t portIn) : CNetAddr(cip), port(portIn) { } -CService::CService(const struct in_addr& ipv4Addr, unsigned short portIn) : CNetAddr(ipv4Addr), port(portIn) +CService::CService(const struct in_addr& ipv4Addr, uint16_t portIn) : CNetAddr(ipv4Addr), port(portIn) { } -CService::CService(const struct in6_addr& ipv6Addr, unsigned short portIn) : CNetAddr(ipv6Addr), port(portIn) +CService::CService(const struct in6_addr& ipv6Addr, uint16_t portIn) : CNetAddr(ipv6Addr), port(portIn) { } @@ -663,7 +671,7 @@ bool CService::SetSockAddr(const struct sockaddr *paddr) } } -unsigned short CService::GetPort() const +uint16_t CService::GetPort() const { return port; } @@ -725,12 +733,10 @@ bool CService::GetSockAddr(struct sockaddr* paddr, socklen_t *addrlen) const */ std::vector<unsigned char> CService::GetKey() const { - std::vector<unsigned char> vKey; - vKey.resize(18); - memcpy(vKey.data(), ip, 16); - vKey[16] = port / 0x100; // most significant byte of our port - vKey[17] = port & 0x0FF; // least significant byte of our port - return vKey; + auto key = GetAddrBytes(); + key.push_back(port / 0x100); // most significant byte of our port + key.push_back(port & 0x0FF); // least significant byte of our port + return key; } std::string CService::ToStringPort() const @@ -814,7 +820,7 @@ CSubNet::CSubNet(const CNetAddr &addr): */ bool CSubNet::Match(const CNetAddr &addr) const { - if (!valid || !addr.IsValid()) + if (!valid || !addr.IsValid() || network.m_net != addr.m_net) return false; for(int x=0; x<16; ++x) if ((addr.ip[x] & netmask[x]) != network.ip[x]) diff --git a/src/netaddress.h b/src/netaddress.h index e640c07d32..0365907d44 100644 --- a/src/netaddress.h +++ b/src/netaddress.h @@ -12,25 +12,54 @@ #include <compat.h> #include <serialize.h> -#include <stdint.h> +#include <cstdint> #include <string> #include <vector> +/** + * A network type. + * @note An address may belong to more than one network, for example `10.0.0.1` + * belongs to both `NET_UNROUTABLE` and `NET_IPV4`. + * Keep these sequential starting from 0 and `NET_MAX` as the last entry. + * We have loops like `for (int i = 0; i < NET_MAX; i++)` that expect to iterate + * over all enum values and also `GetExtNetwork()` "extends" this enum by + * introducing standalone constants starting from `NET_MAX`. + */ enum Network { + /// Addresses from these networks are not publicly routable on the global Internet. NET_UNROUTABLE = 0, + + /// IPv4 NET_IPV4, + + /// IPv6 NET_IPV6, + + /// TORv2 NET_ONION, + + /// A set of dummy addresses that map a name to an IPv6 address. These + /// addresses belong to RFC4193's fc00::/7 subnet (unique-local addresses). + /// We use them to map a string or FQDN to an IPv6 address in CAddrMan to + /// keep track of which DNS seeds were used. NET_INTERNAL, + /// Dummy value to indicate the number of NET_* constants. NET_MAX, }; -/** IP address (IPv6, or IPv4 using mapped IPv6 range (::FFFF:0:0/96)) */ +/** + * Network address. + */ class CNetAddr { protected: + /** + * Network to which this address belongs. + */ + Network m_net{NET_IPV6}; + unsigned char ip[16]; // in network byte order uint32_t scopeId{0}; // for scoped/link-local ipv6 addresses @@ -40,6 +69,14 @@ class CNetAddr void SetIP(const CNetAddr& ip); /** + * Set from a legacy IPv6 address. + * Legacy IPv6 address may be a normal IPv6 address, or another address + * (e.g. IPv4) disguised as IPv6. This encoding is used in the legacy + * `addr` encoding. + */ + void SetLegacyIPv6(const uint8_t ipv6[16]); + + /** * Set raw IPv4 or IPv6 address (in network byte order) * @note Only NET_IPV4 and NET_IPV6 are allowed for network. */ @@ -90,6 +127,7 @@ class CNetAddr uint32_t GetMappedAS(const std::vector<bool> &asmap) const; std::vector<unsigned char> GetGroup(const std::vector<bool> &asmap) const; + std::vector<unsigned char> GetAddrBytes() const { return {std::begin(ip), std::end(ip)}; } int GetReachabilityFrom(const CNetAddr *paddrPartner = nullptr) const; explicit CNetAddr(const struct in6_addr& pipv6Addr, const uint32_t scope = 0); @@ -99,7 +137,27 @@ class CNetAddr friend bool operator!=(const CNetAddr& a, const CNetAddr& b) { return !(a == b); } friend bool operator<(const CNetAddr& a, const CNetAddr& b); - SERIALIZE_METHODS(CNetAddr, obj) { READWRITE(obj.ip); } + /** + * Serialize to a stream. + */ + template <typename Stream> + void Serialize(Stream& s) const + { + s << ip; + } + + /** + * Unserialize from a stream. + */ + template <typename Stream> + void Unserialize(Stream& s) + { + unsigned char ip_temp[sizeof(ip)]; + s >> ip_temp; + // Use SetLegacyIPv6() so that m_net is set correctly. For example + // ::FFFF:0102:0304 should be set as m_net=NET_IPV4 (1.2.3.4). + SetLegacyIPv6(ip_temp); + } friend class CSubNet; }; @@ -142,10 +200,10 @@ class CService : public CNetAddr public: CService(); - CService(const CNetAddr& ip, unsigned short port); - CService(const struct in_addr& ipv4Addr, unsigned short port); + CService(const CNetAddr& ip, uint16_t port); + CService(const struct in_addr& ipv4Addr, uint16_t port); explicit CService(const struct sockaddr_in& addr); - unsigned short GetPort() const; + uint16_t GetPort() const; bool GetSockAddr(struct sockaddr* paddr, socklen_t *addrlen) const; bool SetSockAddr(const struct sockaddr* paddr); friend bool operator==(const CService& a, const CService& b); @@ -156,10 +214,14 @@ class CService : public CNetAddr std::string ToStringPort() const; std::string ToStringIPPort() const; - CService(const struct in6_addr& ipv6Addr, unsigned short port); + CService(const struct in6_addr& ipv6Addr, uint16_t port); explicit CService(const struct sockaddr_in6& addr); - SERIALIZE_METHODS(CService, obj) { READWRITE(obj.ip, Using<BigEndianFormatter<2>>(obj.port)); } + SERIALIZE_METHODS(CService, obj) + { + READWRITEAS(CNetAddr, obj); + READWRITE(Using<BigEndianFormatter<2>>(obj.port)); + } }; bool SanityCheckASMap(const std::vector<bool>& asmap); diff --git a/src/netbase.cpp b/src/netbase.cpp index a70179cb16..3a3b5f3e66 100644 --- a/src/netbase.cpp +++ b/src/netbase.cpp @@ -12,6 +12,7 @@ #include <util/system.h> #include <atomic> +#include <cstdint> #ifndef WIN32 #include <fcntl.h> @@ -28,9 +29,9 @@ #endif // Settings -static RecursiveMutex cs_proxyInfos; -static proxyType proxyInfo[NET_MAX] GUARDED_BY(cs_proxyInfos); -static proxyType nameProxy GUARDED_BY(cs_proxyInfos); +static Mutex g_proxyinfo_mutex; +static proxyType proxyInfo[NET_MAX] GUARDED_BY(g_proxyinfo_mutex); +static proxyType nameProxy GUARDED_BY(g_proxyinfo_mutex); int nConnectTimeout = DEFAULT_CONNECT_TIMEOUT; bool fNameLookup = DEFAULT_NAME_LOOKUP; @@ -711,14 +712,14 @@ bool SetProxy(enum Network net, const proxyType &addrProxy) { assert(net >= 0 && net < NET_MAX); if (!addrProxy.IsValid()) return false; - LOCK(cs_proxyInfos); + LOCK(g_proxyinfo_mutex); proxyInfo[net] = addrProxy; return true; } bool GetProxy(enum Network net, proxyType &proxyInfoOut) { assert(net >= 0 && net < NET_MAX); - LOCK(cs_proxyInfos); + LOCK(g_proxyinfo_mutex); if (!proxyInfo[net].IsValid()) return false; proxyInfoOut = proxyInfo[net]; @@ -744,13 +745,13 @@ bool GetProxy(enum Network net, proxyType &proxyInfoOut) { bool SetNameProxy(const proxyType &addrProxy) { if (!addrProxy.IsValid()) return false; - LOCK(cs_proxyInfos); + LOCK(g_proxyinfo_mutex); nameProxy = addrProxy; return true; } bool GetNameProxy(proxyType &nameProxyOut) { - LOCK(cs_proxyInfos); + LOCK(g_proxyinfo_mutex); if(!nameProxy.IsValid()) return false; nameProxyOut = nameProxy; @@ -758,12 +759,12 @@ bool GetNameProxy(proxyType &nameProxyOut) { } bool HaveNameProxy() { - LOCK(cs_proxyInfos); + LOCK(g_proxyinfo_mutex); return nameProxy.IsValid(); } bool IsProxy(const CNetAddr &addr) { - LOCK(cs_proxyInfos); + LOCK(g_proxyinfo_mutex); for (int i = 0; i < NET_MAX; i++) { if (addr == static_cast<CNetAddr>(proxyInfo[i].proxy)) return true; @@ -798,11 +799,11 @@ bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int ProxyCredentials random_auth; static std::atomic_int counter(0); random_auth.username = random_auth.password = strprintf("%i", counter++); - if (!Socks5(strDest, (unsigned short)port, &random_auth, hSocket)) { + if (!Socks5(strDest, (uint16_t)port, &random_auth, hSocket)) { return false; } } else { - if (!Socks5(strDest, (unsigned short)port, 0, hSocket)) { + if (!Socks5(strDest, (uint16_t)port, 0, hSocket)) { return false; } } diff --git a/src/netmessagemaker.h b/src/netmessagemaker.h index 2efb384a7b..ffb3fe2f29 100644 --- a/src/netmessagemaker.h +++ b/src/netmessagemaker.h @@ -15,18 +15,18 @@ public: explicit CNetMsgMaker(int nVersionIn) : nVersion(nVersionIn){} template <typename... Args> - CSerializedNetMsg Make(int nFlags, std::string sCommand, Args&&... args) const + CSerializedNetMsg Make(int nFlags, std::string msg_type, Args&&... args) const { CSerializedNetMsg msg; - msg.command = std::move(sCommand); + msg.m_type = std::move(msg_type); CVectorWriter{ SER_NETWORK, nFlags | nVersion, msg.data, 0, std::forward<Args>(args)... }; return msg; } template <typename... Args> - CSerializedNetMsg Make(std::string sCommand, Args&&... args) const + CSerializedNetMsg Make(std::string msg_type, Args&&... args) const { - return Make(0, std::move(sCommand), std::forward<Args>(args)...); + return Make(0, std::move(msg_type), std::forward<Args>(args)...); } private: diff --git a/src/node/coinstats.cpp b/src/node/coinstats.cpp index e3c4c828b6..fb46ea1731 100644 --- a/src/node/coinstats.cpp +++ b/src/node/coinstats.cpp @@ -8,13 +8,23 @@ #include <coins.h> #include <hash.h> #include <serialize.h> -#include <validation.h> #include <uint256.h> #include <util/system.h> +#include <validation.h> #include <map> -static void ApplyStats(CCoinsStats &stats, CHashWriter& ss, const uint256& hash, const std::map<uint32_t, Coin>& outputs) +static uint64_t GetBogoSize(const CScript& scriptPubKey) +{ + return 32 /* txid */ + + 4 /* vout index */ + + 4 /* height + coinbase */ + + 8 /* amount */ + + 2 /* scriptPubKey len */ + + scriptPubKey.size() /* scriptPubKey */; +} + +static void ApplyStats(CCoinsStats& stats, CHashWriter& ss, const uint256& hash, const std::map<uint32_t, Coin>& outputs) { assert(!outputs.empty()); ss << hash; @@ -26,26 +36,38 @@ static void ApplyStats(CCoinsStats &stats, CHashWriter& ss, const uint256& hash, ss << VARINT_MODE(output.second.out.nValue, VarIntMode::NONNEGATIVE_SIGNED); stats.nTransactionOutputs++; stats.nTotalAmount += output.second.out.nValue; - stats.nBogoSize += 32 /* txid */ + 4 /* vout index */ + 4 /* height + coinbase */ + 8 /* amount */ + - 2 /* scriptPubKey len */ + output.second.out.scriptPubKey.size() /* scriptPubKey */; + stats.nBogoSize += GetBogoSize(output.second.out.scriptPubKey); } ss << VARINT(0u); } +static void ApplyStats(CCoinsStats& stats, std::nullptr_t, const uint256& hash, const std::map<uint32_t, Coin>& outputs) +{ + assert(!outputs.empty()); + stats.nTransactions++; + for (const auto& output : outputs) { + stats.nTransactionOutputs++; + stats.nTotalAmount += output.second.out.nValue; + stats.nBogoSize += GetBogoSize(output.second.out.scriptPubKey); + } +} + //! Calculate statistics about the unspent transaction output set -bool GetUTXOStats(CCoinsView* view, CCoinsStats& stats, const std::function<void()>& interruption_point) +template <typename T> +static bool GetUTXOStats(CCoinsView* view, CCoinsStats& stats, T hash_obj, const std::function<void()>& interruption_point) { stats = CCoinsStats(); std::unique_ptr<CCoinsViewCursor> pcursor(view->Cursor()); assert(pcursor); - CHashWriter ss(SER_GETHASH, PROTOCOL_VERSION); stats.hashBlock = pcursor->GetBestBlock(); { LOCK(cs_main); stats.nHeight = LookupBlockIndex(stats.hashBlock)->nHeight; } - ss << stats.hashBlock; + + PrepareHash(hash_obj, stats); + uint256 prevkey; std::map<uint32_t, Coin> outputs; while (pcursor->Valid()) { @@ -54,7 +76,7 @@ bool GetUTXOStats(CCoinsView* view, CCoinsStats& stats, const std::function<void Coin coin; if (pcursor->GetKey(key) && pcursor->GetValue(coin)) { if (!outputs.empty() && key.hash != prevkey) { - ApplyStats(stats, ss, prevkey, outputs); + ApplyStats(stats, hash_obj, prevkey, outputs); outputs.clear(); } prevkey = key.hash; @@ -66,9 +88,38 @@ bool GetUTXOStats(CCoinsView* view, CCoinsStats& stats, const std::function<void pcursor->Next(); } if (!outputs.empty()) { - ApplyStats(stats, ss, prevkey, outputs); + ApplyStats(stats, hash_obj, prevkey, outputs); } - stats.hashSerialized = ss.GetHash(); + + FinalizeHash(hash_obj, stats); + stats.nDiskSize = view->EstimateSize(); return true; } + +bool GetUTXOStats(CCoinsView* view, CCoinsStats& stats, CoinStatsHashType hash_type, const std::function<void()>& interruption_point) +{ + switch (hash_type) { + case(CoinStatsHashType::HASH_SERIALIZED): { + CHashWriter ss(SER_GETHASH, PROTOCOL_VERSION); + return GetUTXOStats(view, stats, ss, interruption_point); + } + case(CoinStatsHashType::NONE): { + return GetUTXOStats(view, stats, nullptr, interruption_point); + } + } // no default case, so the compiler can warn about missing cases + assert(false); +} + +// The legacy hash serializes the hashBlock +static void PrepareHash(CHashWriter& ss, CCoinsStats& stats) +{ + ss << stats.hashBlock; +} +static void PrepareHash(std::nullptr_t, CCoinsStats& stats) {} + +static void FinalizeHash(CHashWriter& ss, CCoinsStats& stats) +{ + stats.hashSerialized = ss.GetHash(); +} +static void FinalizeHash(std::nullptr_t, CCoinsStats& stats) {} diff --git a/src/node/coinstats.h b/src/node/coinstats.h index d9cdaa3036..2a7441c10e 100644 --- a/src/node/coinstats.h +++ b/src/node/coinstats.h @@ -14,6 +14,11 @@ class CCoinsView; +enum class CoinStatsHashType { + HASH_SERIALIZED, + NONE, +}; + struct CCoinsStats { int nHeight{0}; @@ -30,6 +35,6 @@ struct CCoinsStats }; //! Calculate statistics about the unspent transaction output set -bool GetUTXOStats(CCoinsView* view, CCoinsStats& stats, const std::function<void()>& interruption_point = {}); +bool GetUTXOStats(CCoinsView* view, CCoinsStats& stats, const CoinStatsHashType hash_type, const std::function<void()>& interruption_point = {}); #endif // BITCOIN_NODE_COINSTATS_H diff --git a/src/node/context.h b/src/node/context.h index c45d9e6689..be568cba36 100644 --- a/src/node/context.h +++ b/src/node/context.h @@ -6,6 +6,7 @@ #define BITCOIN_NODE_CONTEXT_H #include <cassert> +#include <functional> #include <memory> #include <vector> @@ -41,6 +42,7 @@ struct NodeContext { std::unique_ptr<interfaces::Chain> chain; std::vector<std::unique_ptr<interfaces::ChainClient>> chain_clients; std::unique_ptr<CScheduler> scheduler; + std::function<void()> rpc_interruption_point = [] {}; //! Declare default constructor and destructor that are not inline, so code //! instantiating the NodeContext struct doesn't need to #include class @@ -49,10 +51,4 @@ struct NodeContext { ~NodeContext(); }; -inline ChainstateManager& EnsureChainman(const NodeContext& node) -{ - assert(node.chainman); - return *node.chainman; -} - #endif // BITCOIN_NODE_CONTEXT_H diff --git a/src/node/transaction.cpp b/src/node/transaction.cpp index 3841d8687d..5633abe817 100644 --- a/src/node/transaction.cpp +++ b/src/node/transaction.cpp @@ -80,9 +80,10 @@ TransactionError BroadcastTransaction(NodeContext& node, const CTransactionRef t if (relay) { // the mempool tracks locally submitted transactions to make a // best-effort of initial broadcast - node.mempool->AddUnbroadcastTx(hashTx); + node.mempool->AddUnbroadcastTx(hashTx, tx->GetWitnessHash()); - RelayTransaction(hashTx, *node.connman); + LOCK(cs_main); + RelayTransaction(hashTx, tx->GetWitnessHash(), *node.connman); } return TransactionError::OK; diff --git a/src/ui_interface.cpp b/src/node/ui_interface.cpp index 15795bd67f..8d3665975d 100644 --- a/src/ui_interface.cpp +++ b/src/node/ui_interface.cpp @@ -2,18 +2,18 @@ // Distributed under the MIT software license, see the accompanying // file COPYING or http://www.opensource.org/licenses/mit-license.php. -#include <ui_interface.h> +#include <node/ui_interface.h> #include <util/translation.h> -#include <boost/signals2/last_value.hpp> +#include <boost/signals2/optional_last_value.hpp> #include <boost/signals2/signal.hpp> CClientUIInterface uiInterface; struct UISignals { - boost::signals2::signal<CClientUIInterface::ThreadSafeMessageBoxSig, boost::signals2::last_value<bool>> ThreadSafeMessageBox; - boost::signals2::signal<CClientUIInterface::ThreadSafeQuestionSig, boost::signals2::last_value<bool>> ThreadSafeQuestion; + boost::signals2::signal<CClientUIInterface::ThreadSafeMessageBoxSig, boost::signals2::optional_last_value<bool>> ThreadSafeMessageBox; + boost::signals2::signal<CClientUIInterface::ThreadSafeQuestionSig, boost::signals2::optional_last_value<bool>> ThreadSafeQuestion; boost::signals2::signal<CClientUIInterface::InitMessageSig> InitMessage; boost::signals2::signal<CClientUIInterface::NotifyNumConnectionsChangedSig> NotifyNumConnectionsChanged; boost::signals2::signal<CClientUIInterface::NotifyNetworkActiveChangedSig> NotifyNetworkActiveChanged; @@ -42,8 +42,8 @@ ADD_SIGNALS_IMPL_WRAPPER(NotifyBlockTip); ADD_SIGNALS_IMPL_WRAPPER(NotifyHeaderTip); ADD_SIGNALS_IMPL_WRAPPER(BannedListChanged); -bool CClientUIInterface::ThreadSafeMessageBox(const bilingual_str& message, const std::string& caption, unsigned int style) { return g_ui_signals.ThreadSafeMessageBox(message, caption, style); } -bool CClientUIInterface::ThreadSafeQuestion(const bilingual_str& message, const std::string& non_interactive_message, const std::string& caption, unsigned int style) { return g_ui_signals.ThreadSafeQuestion(message, non_interactive_message, caption, style); } +bool CClientUIInterface::ThreadSafeMessageBox(const bilingual_str& message, const std::string& caption, unsigned int style) { return g_ui_signals.ThreadSafeMessageBox(message, caption, style).value_or(false);} +bool CClientUIInterface::ThreadSafeQuestion(const bilingual_str& message, const std::string& non_interactive_message, const std::string& caption, unsigned int style) { return g_ui_signals.ThreadSafeQuestion(message, non_interactive_message, caption, style).value_or(false);} void CClientUIInterface::InitMessage(const std::string& message) { return g_ui_signals.InitMessage(message); } void CClientUIInterface::NotifyNumConnectionsChanged(int newNumConnections) { return g_ui_signals.NotifyNumConnectionsChanged(newNumConnections); } void CClientUIInterface::NotifyNetworkActiveChanged(bool networkActive) { return g_ui_signals.NotifyNetworkActiveChanged(networkActive); } diff --git a/src/ui_interface.h b/src/node/ui_interface.h index d45811178f..d574ab879f 100644 --- a/src/ui_interface.h +++ b/src/node/ui_interface.h @@ -3,8 +3,8 @@ // Distributed under the MIT software license, see the accompanying // file COPYING or http://www.opensource.org/licenses/mit-license.php. -#ifndef BITCOIN_UI_INTERFACE_H -#define BITCOIN_UI_INTERFACE_H +#ifndef BITCOIN_NODE_UI_INTERFACE_H +#define BITCOIN_NODE_UI_INTERFACE_H #include <functional> #include <memory> @@ -20,14 +20,6 @@ class connection; } } // namespace boost -/** General change type (added, updated, removed). */ -enum ChangeType -{ - CT_NEW, - CT_UPDATED, - CT_DELETED -}; - /** Signals for UI communication. */ class CClientUIInterface { @@ -67,9 +59,6 @@ public: /** Force blocking, modal message box dialog (not just OS notification) */ MODAL = 0x10000000U, - /** Do not prepend error/warning prefix */ - MSG_NOPREFIX = 0x20000000U, - /** Do not print contents of message to debug log */ SECURE = 0x40000000U, @@ -125,7 +114,8 @@ void InitWarning(const bilingual_str& str); /** Show error message **/ bool InitError(const bilingual_str& str); +constexpr auto AbortError = InitError; extern CClientUIInterface uiInterface; -#endif // BITCOIN_UI_INTERFACE_H +#endif // BITCOIN_NODE_UI_INTERFACE_H diff --git a/src/noui.cpp b/src/noui.cpp index ddb3a50ff7..3c82512fac 100644 --- a/src/noui.cpp +++ b/src/noui.cpp @@ -6,7 +6,7 @@ #include <noui.h> #include <logging.h> -#include <ui_interface.h> +#include <node/ui_interface.h> #include <util/translation.h> #include <string> @@ -23,24 +23,20 @@ bool noui_ThreadSafeMessageBox(const bilingual_str& message, const std::string& { bool fSecure = style & CClientUIInterface::SECURE; style &= ~CClientUIInterface::SECURE; - bool prefix = !(style & CClientUIInterface::MSG_NOPREFIX); - style &= ~CClientUIInterface::MSG_NOPREFIX; std::string strCaption; - if (prefix) { - switch (style) { - case CClientUIInterface::MSG_ERROR: - strCaption = "Error: "; - break; - case CClientUIInterface::MSG_WARNING: - strCaption = "Warning: "; - break; - case CClientUIInterface::MSG_INFORMATION: - strCaption = "Information: "; - break; - default: - strCaption = caption + ": "; // Use supplied caption (can be empty) - } + switch (style) { + case CClientUIInterface::MSG_ERROR: + strCaption = "Error: "; + break; + case CClientUIInterface::MSG_WARNING: + strCaption = "Warning: "; + break; + case CClientUIInterface::MSG_INFORMATION: + strCaption = "Information: "; + break; + default: + strCaption = caption + ": "; // Use supplied caption (can be empty) } if (!fSecure) { diff --git a/src/outputtype.cpp b/src/outputtype.cpp index ea7a86d6d6..e978852826 100644 --- a/src/outputtype.cpp +++ b/src/outputtype.cpp @@ -42,8 +42,8 @@ const std::string& FormatOutputType(OutputType type) case OutputType::LEGACY: return OUTPUT_TYPE_STRING_LEGACY; case OutputType::P2SH_SEGWIT: return OUTPUT_TYPE_STRING_P2SH_SEGWIT; case OutputType::BECH32: return OUTPUT_TYPE_STRING_BECH32; - default: assert(false); - } + } // no default case, so the compiler can warn about missing cases + assert(false); } CTxDestination GetDestinationForKey(const CPubKey& key, OutputType type) @@ -53,7 +53,7 @@ CTxDestination GetDestinationForKey(const CPubKey& key, OutputType type) case OutputType::P2SH_SEGWIT: case OutputType::BECH32: { if (!key.IsCompressed()) return PKHash(key); - CTxDestination witdest = WitnessV0KeyHash(PKHash(key)); + CTxDestination witdest = WitnessV0KeyHash(key); CScript witprog = GetScriptForDestination(witdest); if (type == OutputType::P2SH_SEGWIT) { return ScriptHash(witprog); @@ -61,8 +61,8 @@ CTxDestination GetDestinationForKey(const CPubKey& key, OutputType type) return witdest; } } - default: assert(false); - } + } // no default case, so the compiler can warn about missing cases + assert(false); } std::vector<CTxDestination> GetAllDestinationsForKey(const CPubKey& key) @@ -100,6 +100,6 @@ CTxDestination AddAndGetDestinationForScript(FillableSigningProvider& keystore, return ScriptHash(witprog); } } - default: assert(false); - } + } // no default case, so the compiler can warn about missing cases + assert(false); } diff --git a/src/outputtype.h b/src/outputtype.h index 1438f65844..77a16b1d05 100644 --- a/src/outputtype.h +++ b/src/outputtype.h @@ -18,14 +18,6 @@ enum class OutputType { LEGACY, P2SH_SEGWIT, BECH32, - - /** - * Special output type for change outputs only. Automatically choose type - * based on address type setting and the types other of non-change outputs - * (see -changetype option documentation and implementation in - * CWallet::TransactionChangeType for details). - */ - CHANGE_AUTO, }; extern const std::array<OutputType, 3> OUTPUT_TYPES; diff --git a/src/policy/feerate.cpp b/src/policy/feerate.cpp index 14be6192fe..a01e259731 100644 --- a/src/policy/feerate.cpp +++ b/src/policy/feerate.cpp @@ -7,8 +7,6 @@ #include <tinyformat.h> -const std::string CURRENCY_UNIT = "BTC"; - CFeeRate::CFeeRate(const CAmount& nFeePaid, size_t nBytes_) { assert(nBytes_ <= uint64_t(std::numeric_limits<int64_t>::max())); @@ -37,7 +35,10 @@ CAmount CFeeRate::GetFee(size_t nBytes_) const return nFee; } -std::string CFeeRate::ToString() const +std::string CFeeRate::ToString(const FeeEstimateMode& fee_estimate_mode) const { - return strprintf("%d.%08d %s/kB", nSatoshisPerK / COIN, nSatoshisPerK % COIN, CURRENCY_UNIT); + switch (fee_estimate_mode) { + case FeeEstimateMode::SAT_B: return strprintf("%d.%03d %s/B", nSatoshisPerK / 1000, nSatoshisPerK % 1000, CURRENCY_ATOM); + default: return strprintf("%d.%08d %s/kB", nSatoshisPerK / COIN, nSatoshisPerK % COIN, CURRENCY_UNIT); + } } diff --git a/src/policy/feerate.h b/src/policy/feerate.h index 61fa80c130..883940f73c 100644 --- a/src/policy/feerate.h +++ b/src/policy/feerate.h @@ -11,7 +11,17 @@ #include <string> -extern const std::string CURRENCY_UNIT; +const std::string CURRENCY_UNIT = "BTC"; // One formatted unit +const std::string CURRENCY_ATOM = "sat"; // One indivisible minimum value unit + +/* Used to determine type of fee estimation requested */ +enum class FeeEstimateMode { + UNSET, //!< Use default settings based on other criteria + ECONOMICAL, //!< Force estimateSmartFee to use non-conservative estimates + CONSERVATIVE, //!< Force estimateSmartFee to use conservative estimates + BTC_KB, //!< Use explicit BTC/kB fee given in coin control + SAT_B, //!< Use explicit sat/B fee given in coin control +}; /** * Fee rate in satoshis per kilobyte: CAmount / kB @@ -46,7 +56,7 @@ public: friend bool operator>=(const CFeeRate& a, const CFeeRate& b) { return a.nSatoshisPerK >= b.nSatoshisPerK; } friend bool operator!=(const CFeeRate& a, const CFeeRate& b) { return a.nSatoshisPerK != b.nSatoshisPerK; } CFeeRate& operator+=(const CFeeRate& a) { nSatoshisPerK += a.nSatoshisPerK; return *this; } - std::string ToString() const; + std::string ToString(const FeeEstimateMode& fee_estimate_mode = FeeEstimateMode::BTC_KB) const; SERIALIZE_METHODS(CFeeRate, obj) { READWRITE(obj.nSatoshisPerK); } }; diff --git a/src/policy/fees.h b/src/policy/fees.h index 6ee6e0d547..e79dbc9868 100644 --- a/src/policy/fees.h +++ b/src/policy/fees.h @@ -45,13 +45,6 @@ enum class FeeReason { REQUIRED, }; -/* Used to determine type of fee estimation requested */ -enum class FeeEstimateMode { - UNSET, //!< Use default settings based on other criteria - ECONOMICAL, //!< Force estimateSmartFee to use non-conservative estimates - CONSERVATIVE, //!< Force estimateSmartFee to use conservative estimates -}; - /* Used to return detailed information about a feerate bucket */ struct EstimatorBucket { @@ -280,7 +273,7 @@ public: /** Create new FeeFilterRounder */ explicit FeeFilterRounder(const CFeeRate& minIncrementalFee); - /** Quantize a minimum fee for privacy purpose before broadcast **/ + /** Quantize a minimum fee for privacy purpose before broadcast. Not thread-safe due to use of FastRandomContext */ CAmount round(CAmount currentMinFee); private: diff --git a/src/policy/policy.cpp b/src/policy/policy.cpp index 07d51c0088..0e9820da1e 100644 --- a/src/policy/policy.cpp +++ b/src/policy/policy.cpp @@ -50,14 +50,14 @@ bool IsDust(const CTxOut& txout, const CFeeRate& dustRelayFeeIn) return (txout.nValue < GetDustThreshold(txout, dustRelayFeeIn)); } -bool IsStandard(const CScript& scriptPubKey, txnouttype& whichType) +bool IsStandard(const CScript& scriptPubKey, TxoutType& whichType) { std::vector<std::vector<unsigned char> > vSolutions; whichType = Solver(scriptPubKey, vSolutions); - if (whichType == TX_NONSTANDARD) { + if (whichType == TxoutType::NONSTANDARD) { return false; - } else if (whichType == TX_MULTISIG) { + } else if (whichType == TxoutType::MULTISIG) { unsigned char m = vSolutions.front()[0]; unsigned char n = vSolutions.back()[0]; // Support up to x-of-3 multisig txns as standard @@ -65,7 +65,7 @@ bool IsStandard(const CScript& scriptPubKey, txnouttype& whichType) return false; if (m < 1 || m > n) return false; - } else if (whichType == TX_NULL_DATA && + } else if (whichType == TxoutType::NULL_DATA && (!fAcceptDatacarrier || scriptPubKey.size() > nMaxDatacarrierBytes)) { return false; } @@ -110,16 +110,16 @@ bool IsStandardTx(const CTransaction& tx, bool permit_bare_multisig, const CFeeR } unsigned int nDataOut = 0; - txnouttype whichType; + TxoutType whichType; for (const CTxOut& txout : tx.vout) { if (!::IsStandard(txout.scriptPubKey, whichType)) { reason = "scriptpubkey"; return false; } - if (whichType == TX_NULL_DATA) + if (whichType == TxoutType::NULL_DATA) nDataOut++; - else if ((whichType == TX_MULTISIG) && (!permit_bare_multisig)) { + else if ((whichType == TxoutType::MULTISIG) && (!permit_bare_multisig)) { reason = "bare-multisig"; return false; } else if (IsDust(txout, dust_relay_fee)) { @@ -152,6 +152,8 @@ bool IsStandardTx(const CTransaction& tx, bool permit_bare_multisig, const CFeeR * script can be anything; an attacker could use a very * expensive-to-check-upon-redemption script like: * DUP CHECKSIG DROP ... repeated 100 times... OP_1 + * + * Note that only the non-witness portion of the transaction is checked here. */ bool AreInputsStandard(const CTransaction& tx, const CCoinsViewCache& mapInputs) { @@ -163,10 +165,14 @@ bool AreInputsStandard(const CTransaction& tx, const CCoinsViewCache& mapInputs) const CTxOut& prev = mapInputs.AccessCoin(tx.vin[i].prevout).out; std::vector<std::vector<unsigned char> > vSolutions; - txnouttype whichType = Solver(prev.scriptPubKey, vSolutions); - if (whichType == TX_NONSTANDARD) { + TxoutType whichType = Solver(prev.scriptPubKey, vSolutions); + if (whichType == TxoutType::NONSTANDARD || whichType == TxoutType::WITNESS_UNKNOWN) { + // WITNESS_UNKNOWN failures are typically also caught with a policy + // flag in the script interpreter, but it can be helpful to catch + // this type of NONSTANDARD transaction earlier in transaction + // validation. return false; - } else if (whichType == TX_SCRIPTHASH) { + } else if (whichType == TxoutType::SCRIPTHASH) { std::vector<std::vector<unsigned char> > stack; // convert the scriptSig into a stack, so we can inspect the redeemScript if (!EvalScript(stack, tx.vin[i].scriptSig, SCRIPT_VERIFY_NONE, BaseSignatureChecker(), SigVersion::BASE)) diff --git a/src/policy/policy.h b/src/policy/policy.h index 1561a41c5e..7f168ee20f 100644 --- a/src/policy/policy.h +++ b/src/policy/policy.h @@ -81,7 +81,7 @@ CAmount GetDustThreshold(const CTxOut& txout, const CFeeRate& dustRelayFee); bool IsDust(const CTxOut& txout, const CFeeRate& dustRelayFee); -bool IsStandard(const CScript& scriptPubKey, txnouttype& whichType); +bool IsStandard(const CScript& scriptPubKey, TxoutType& whichType); /** * Check for standard transaction types * @return True if all outputs (scriptPubKeys) use only standard transaction forms diff --git a/src/primitives/transaction.h b/src/primitives/transaction.h index 4514db578a..544bab6d9b 100644 --- a/src/primitives/transaction.h +++ b/src/primitives/transaction.h @@ -12,6 +12,8 @@ #include <serialize.h> #include <uint256.h> +#include <tuple> + static const int SERIALIZE_TRANSACTION_NO_WITNESS = 0x40000000; /** An outpoint - a combination of a transaction hash and an index n into its vout */ @@ -388,4 +390,17 @@ typedef std::shared_ptr<const CTransaction> CTransactionRef; static inline CTransactionRef MakeTransactionRef() { return std::make_shared<const CTransaction>(); } template <typename Tx> static inline CTransactionRef MakeTransactionRef(Tx&& txIn) { return std::make_shared<const CTransaction>(std::forward<Tx>(txIn)); } +/** A generic txid reference (txid or wtxid). */ +class GenTxid +{ + const bool m_is_wtxid; + const uint256 m_hash; +public: + GenTxid(bool is_wtxid, const uint256& hash) : m_is_wtxid(is_wtxid), m_hash(hash) {} + bool IsWtxid() const { return m_is_wtxid; } + const uint256& GetHash() const { return m_hash; } + friend bool operator==(const GenTxid& a, const GenTxid& b) { return a.m_is_wtxid == b.m_is_wtxid && a.m_hash == b.m_hash; } + friend bool operator<(const GenTxid& a, const GenTxid& b) { return std::tie(a.m_is_wtxid, a.m_hash) < std::tie(b.m_is_wtxid, b.m_hash); } +}; + #endif // BITCOIN_PRIMITIVES_TRANSACTION_H diff --git a/src/protocol.cpp b/src/protocol.cpp index 7f58125f00..c989aa3902 100644 --- a/src/protocol.cpp +++ b/src/protocol.cpp @@ -5,8 +5,8 @@ #include <protocol.h> -#include <util/system.h> #include <util/strencodings.h> +#include <util/system.h> #ifndef WIN32 # include <arpa/inet.h> @@ -46,6 +46,7 @@ const char *GETCFHEADERS="getcfheaders"; const char *CFHEADERS="cfheaders"; const char *GETCFCHECKPT="getcfcheckpt"; const char *CFCHECKPT="cfcheckpt"; +const char *WTXIDRELAY="wtxidrelay"; } // namespace NetMsgType /** All known message types. Keep this in the same order as the list of @@ -83,6 +84,7 @@ const static std::string allNetMessageTypes[] = { NetMsgType::CFHEADERS, NetMsgType::GETCFCHECKPT, NetMsgType::CFCHECKPT, + NetMsgType::WTXIDRELAY, }; const static std::vector<std::string> allNetMessageTypesVec(allNetMessageTypes, allNetMessageTypes+ARRAYLEN(allNetMessageTypes)); @@ -177,6 +179,8 @@ std::string CInv::GetCommand() const switch (masked) { case MSG_TX: return cmd.append(NetMsgType::TX); + // WTX is not a message type, just an inv type + case MSG_WTX: return cmd.append("wtx"); case MSG_BLOCK: return cmd.append(NetMsgType::BLOCK); case MSG_FILTERED_BLOCK: return cmd.append(NetMsgType::MERKLEBLOCK); case MSG_CMPCT_BLOCK: return cmd.append(NetMsgType::CMPCTBLOCK); @@ -221,11 +225,7 @@ static std::string serviceFlagToStr(size_t bit) std::ostringstream stream; stream.imbue(std::locale::classic()); stream << "UNKNOWN["; - if (bit < 8) { - stream << service_flag; - } else { - stream << "2^" << bit; - } + stream << "2^" << bit; stream << "]"; return stream.str(); } @@ -242,3 +242,9 @@ std::vector<std::string> serviceFlagsToStr(uint64_t flags) return str_flags; } + +GenTxid ToGenTxid(const CInv& inv) +{ + assert(inv.IsGenTxMsg()); + return {inv.IsMsgWtx(), inv.hash}; +} diff --git a/src/protocol.h b/src/protocol.h index a68f30287d..0bef12ee62 100644 --- a/src/protocol.h +++ b/src/protocol.h @@ -11,6 +11,7 @@ #define BITCOIN_PROTOCOL_H #include <netaddress.h> +#include <primitives/transaction.h> #include <serialize.h> #include <uint256.h> #include <version.h> @@ -63,100 +64,84 @@ namespace NetMsgType { /** * The version message provides information about the transmitting node to the * receiving node at the beginning of a connection. - * @see https://bitcoin.org/en/developer-reference#version */ extern const char* VERSION; /** * The verack message acknowledges a previously-received version message, * informing the connecting node that it can begin to send other messages. - * @see https://bitcoin.org/en/developer-reference#verack */ extern const char* VERACK; /** * The addr (IP address) message relays connection information for peers on the * network. - * @see https://bitcoin.org/en/developer-reference#addr */ extern const char* ADDR; /** * The inv message (inventory message) transmits one or more inventories of * objects known to the transmitting peer. - * @see https://bitcoin.org/en/developer-reference#inv */ extern const char* INV; /** * The getdata message requests one or more data objects from another node. - * @see https://bitcoin.org/en/developer-reference#getdata */ extern const char* GETDATA; /** * The merkleblock message is a reply to a getdata message which requested a * block using the inventory type MSG_MERKLEBLOCK. * @since protocol version 70001 as described by BIP37. - * @see https://bitcoin.org/en/developer-reference#merkleblock */ extern const char* MERKLEBLOCK; /** * The getblocks message requests an inv message that provides block header * hashes starting from a particular point in the block chain. - * @see https://bitcoin.org/en/developer-reference#getblocks */ extern const char* GETBLOCKS; /** * The getheaders message requests a headers message that provides block * headers starting from a particular point in the block chain. * @since protocol version 31800. - * @see https://bitcoin.org/en/developer-reference#getheaders */ extern const char* GETHEADERS; /** * The tx message transmits a single transaction. - * @see https://bitcoin.org/en/developer-reference#tx */ extern const char* TX; /** * The headers message sends one or more block headers to a node which * previously requested certain headers with a getheaders message. * @since protocol version 31800. - * @see https://bitcoin.org/en/developer-reference#headers */ extern const char* HEADERS; /** * The block message transmits a single serialized block. - * @see https://bitcoin.org/en/developer-reference#block */ extern const char* BLOCK; /** * The getaddr message requests an addr message from the receiving node, * preferably one with lots of IP addresses of other receiving nodes. - * @see https://bitcoin.org/en/developer-reference#getaddr */ extern const char* GETADDR; /** * The mempool message requests the TXIDs of transactions that the receiving * node has verified as valid but which have not yet appeared in a block. * @since protocol version 60002. - * @see https://bitcoin.org/en/developer-reference#mempool */ extern const char* MEMPOOL; /** * The ping message is sent periodically to help confirm that the receiving * peer is still connected. - * @see https://bitcoin.org/en/developer-reference#ping */ extern const char* PING; /** * The pong message replies to a ping message, proving to the pinging node that * the ponging node is still alive. * @since protocol version 60001 as described by BIP31. - * @see https://bitcoin.org/en/developer-reference#pong */ extern const char* PONG; /** * The notfound message is a reply to a getdata message which requested an * object the receiving node does not have available for relay. * @since protocol version 70001. - * @see https://bitcoin.org/en/developer-reference#notfound */ extern const char* NOTFOUND; /** @@ -165,7 +150,6 @@ extern const char* NOTFOUND; * @since protocol version 70001 as described by BIP37. * Only available with service bit NODE_BLOOM since protocol version * 70011 as described by BIP111. - * @see https://bitcoin.org/en/developer-reference#filterload */ extern const char* FILTERLOAD; /** @@ -174,7 +158,6 @@ extern const char* FILTERLOAD; * @since protocol version 70001 as described by BIP37. * Only available with service bit NODE_BLOOM since protocol version * 70011 as described by BIP111. - * @see https://bitcoin.org/en/developer-reference#filteradd */ extern const char* FILTERADD; /** @@ -183,14 +166,12 @@ extern const char* FILTERADD; * @since protocol version 70001 as described by BIP37. * Only available with service bit NODE_BLOOM since protocol version * 70011 as described by BIP111. - * @see https://bitcoin.org/en/developer-reference#filterclear */ extern const char* FILTERCLEAR; /** * Indicates that a node prefers to receive new block announcements via a * "headers" message rather than an "inv". * @since protocol version 70012 as described by BIP130. - * @see https://bitcoin.org/en/developer-reference#sendheaders */ extern const char* SENDHEADERS; /** @@ -261,6 +242,12 @@ extern const char* GETCFCHECKPT; * evenly spaced filter headers for blocks on the requested chain. */ extern const char* CFCHECKPT; +/** + * Indicates that a node prefers to relay transactions via wtxid, rather than + * txid. + * @since protocol version 70016 as described by BIP 339. + */ +extern const char *WTXIDRELAY; }; // namespace NetMsgType /* Get a vector of all valid message types (see above) */ @@ -374,7 +361,13 @@ public: READWRITE(nVersion); } if ((s.GetType() & SER_DISK) || - (nVersion >= CADDR_TIME_VERSION && !(s.GetType() & SER_GETHASH))) { + (nVersion != INIT_PROTO_VERSION && !(s.GetType() & SER_GETHASH))) { + // The only time we serialize a CAddress object without nTime is in + // the initial VERSION messages which contain two CAddress records. + // At that point, the serialization version is INIT_PROTO_VERSION. + // After the version handshake, serialization version is >= + // MIN_PEER_PROTO_VERSION and all ADDR messages are serialized with + // nTime. READWRITE(obj.nTime); } READWRITE(Using<CustomUintFormatter<8>>(obj.nServices)); @@ -394,11 +387,12 @@ const uint32_t MSG_TYPE_MASK = 0xffffffff >> 2; * These numbers are defined by the protocol. When adding a new value, be sure * to mention it in the respective BIP. */ -enum GetDataMsg { +enum GetDataMsg : uint32_t { UNDEFINED = 0, MSG_TX = 1, MSG_BLOCK = 2, - // The following can only occur in getdata. Invs always use TX or BLOCK. + MSG_WTX = 5, //!< Defined in BIP 339 + // The following can only occur in getdata. Invs always use TX/WTX or BLOCK. MSG_FILTERED_BLOCK = 3, //!< Defined in BIP37 MSG_CMPCT_BLOCK = 4, //!< Defined in BIP152 MSG_WITNESS_BLOCK = MSG_BLOCK | MSG_WITNESS_FLAG, //!< Defined in BIP144 @@ -420,8 +414,19 @@ public: std::string GetCommand() const; std::string ToString() const; + // Single-message helper methods + bool IsMsgTx() const { return type == MSG_TX; } + bool IsMsgWtx() const { return type == MSG_WTX; } + bool IsMsgWitnessTx() const { return type == MSG_WITNESS_TX; } + + // Combined-message helper methods + bool IsGenTxMsg() const { return type == MSG_TX || type == MSG_WTX || type == MSG_WITNESS_TX; } + int type; uint256 hash; }; +/** Convert a TX/WITNESS_TX/WTX CInv to a GenTxid. */ +GenTxid ToGenTxid(const CInv& inv); + #endif // BITCOIN_PROTOCOL_H diff --git a/src/psbt.cpp b/src/psbt.cpp index ef9781817a..3fb743e5db 100644 --- a/src/psbt.cpp +++ b/src/psbt.cpp @@ -35,14 +35,6 @@ bool PartiallySignedTransaction::Merge(const PartiallySignedTransaction& psbt) return true; } -bool PartiallySignedTransaction::IsSane() const -{ - for (PSBTInput input : inputs) { - if (!input.IsSane()) return false; - } - return true; -} - bool PartiallySignedTransaction::AddInput(const CTxIn& txin, PSBTInput& psbtin) { if (std::find(tx->vin.begin(), tx->vin.end(), txin) != tx->vin.end()) { @@ -144,8 +136,8 @@ void PSBTInput::Merge(const PSBTInput& input) { if (!non_witness_utxo && input.non_witness_utxo) non_witness_utxo = input.non_witness_utxo; if (witness_utxo.IsNull() && !input.witness_utxo.IsNull()) { + // TODO: For segwit v1, we will want to clear out the non-witness utxo when setting a witness one. For v0 and non-segwit, this is not safe witness_utxo = input.witness_utxo; - non_witness_utxo = nullptr; // Clear out any non-witness utxo when we set a witness one. } partial_sigs.insert(input.partial_sigs.begin(), input.partial_sigs.end()); @@ -158,18 +150,6 @@ void PSBTInput::Merge(const PSBTInput& input) if (final_script_witness.IsNull() && !input.final_script_witness.IsNull()) final_script_witness = input.final_script_witness; } -bool PSBTInput::IsSane() const -{ - // Cannot have both witness and non-witness utxos - if (!witness_utxo.IsNull() && non_witness_utxo) return false; - - // If we have a witness_script or a scriptWitness, we must also have a witness utxo - if (!witness_script.empty() && witness_utxo.IsNull()) return false; - if (!final_script_witness.IsNull() && witness_utxo.IsNull()) return false; - - return true; -} - void PSBTOutput::FillSignatureData(SignatureData& sigdata) const { if (!redeem_script.empty()) { @@ -214,6 +194,17 @@ bool PSBTInputSigned(const PSBTInput& input) return !input.final_script_sig.empty() || !input.final_script_witness.IsNull(); } +size_t CountPSBTUnsignedInputs(const PartiallySignedTransaction& psbt) { + size_t count = 0; + for (const auto& input : psbt.inputs) { + if (!PSBTInputSigned(input)) { + count++; + } + } + + return count; +} + void UpdatePSBTOutput(const SigningProvider& provider, PartiallySignedTransaction& psbt, int index) { const CTxOut& out = psbt.tx->vout.at(index); @@ -250,11 +241,6 @@ bool SignPSBTInput(const SigningProvider& provider, PartiallySignedTransaction& bool require_witness_sig = false; CTxOut utxo; - // Verify input sanity, which checks that at most one of witness or non-witness utxos is provided. - if (!input.IsSane()) { - return false; - } - if (input.non_witness_utxo) { // If we're taking our information from a non-witness UTXO, verify that it matches the prevout. COutPoint prevout = tx.vin[index].prevout; @@ -288,10 +274,11 @@ bool SignPSBTInput(const SigningProvider& provider, PartiallySignedTransaction& if (require_witness_sig && !sigdata.witness) return false; input.FromSignatureData(sigdata); - // If we have a witness signature, use the smaller witness UTXO. + // If we have a witness signature, put a witness UTXO. + // TODO: For segwit v1, we should remove the non_witness_utxo if (sigdata.witness) { input.witness_utxo = utxo; - input.non_witness_utxo = nullptr; + // input.non_witness_utxo = nullptr; } // Fill in the missing info @@ -345,10 +332,6 @@ TransactionError CombinePSBTs(PartiallySignedTransaction& out, const std::vector return TransactionError::PSBT_MISMATCH; } } - if (!out.IsSane()) { - return TransactionError::INVALID_PSBT; - } - return TransactionError::OK; } diff --git a/src/psbt.h b/src/psbt.h index af57994f3a..0951b76f83 100644 --- a/src/psbt.h +++ b/src/psbt.h @@ -41,7 +41,7 @@ static constexpr uint8_t PSBT_OUT_BIP32_DERIVATION = 0x02; static constexpr uint8_t PSBT_SEPARATOR = 0x00; // BIP 174 does not specify a maximum file size, but we set a limit anyway -// to prevent reading a stream indefinately and running out of memory. +// to prevent reading a stream indefinitely and running out of memory. const std::streamsize MAX_FILE_SIZE_PSBT = 100000000; // 100 MiB /** A structure for PSBTs which contain per-input information */ @@ -62,18 +62,17 @@ struct PSBTInput void FillSignatureData(SignatureData& sigdata) const; void FromSignatureData(const SignatureData& sigdata); void Merge(const PSBTInput& input); - bool IsSane() const; PSBTInput() {} template <typename Stream> inline void Serialize(Stream& s) const { // Write the utxo - // If there is a non-witness utxo, then don't add the witness one. if (non_witness_utxo) { SerializeToVector(s, PSBT_IN_NON_WITNESS_UTXO); OverrideStream<Stream> os(&s, s.GetType(), s.GetVersion() | SERIALIZE_TRANSACTION_NO_WITNESS); SerializeToVector(os, non_witness_utxo); - } else if (!witness_utxo.IsNull()) { + } + if (!witness_utxo.IsNull()) { SerializeToVector(s, PSBT_IN_WITNESS_UTXO); SerializeToVector(s, witness_utxo); } @@ -284,7 +283,6 @@ struct PSBTOutput void FillSignatureData(SignatureData& sigdata) const; void FromSignatureData(const SignatureData& sigdata); void Merge(const PSBTOutput& output); - bool IsSane() const; PSBTOutput() {} template <typename Stream> @@ -401,7 +399,6 @@ struct PartiallySignedTransaction /** Merge psbt into this. The two psbts must have the same underlying CTransaction (i.e. the * same actual Bitcoin transaction.) Returns true if the merge succeeded, false otherwise. */ NODISCARD bool Merge(const PartiallySignedTransaction& psbt); - bool IsSane() const; bool AddInput(const CTxIn& txin, PSBTInput& psbtin); bool AddOutput(const CTxOut& txout, const PSBTOutput& psbtout); PartiallySignedTransaction() {} @@ -551,10 +548,6 @@ struct PartiallySignedTransaction if (outputs.size() != tx->vout.size()) { throw std::ios_base::failure("Outputs provided does not match the number of outputs in transaction."); } - // Sanity check - if (!IsSane()) { - throw std::ios_base::failure("PSBT is not sane."); - } } template <typename Stream> @@ -579,6 +572,9 @@ bool PSBTInputSigned(const PSBTInput& input); /** Signs a PSBTInput, verifying that all provided data matches what is being signed. */ bool SignPSBTInput(const SigningProvider& provider, PartiallySignedTransaction& psbt, int index, int sighash = SIGHASH_ALL, SignatureData* out_sigdata = nullptr, bool use_dummy = false); +/** Counts the unsigned inputs of a PSBT. */ +size_t CountPSBTUnsignedInputs(const PartiallySignedTransaction& psbt); + /** Updates a PSBTOutput with information from provider. * * This fills in the redeem_script, witness_script, and hd_keypaths where possible. diff --git a/src/pubkey.h b/src/pubkey.h index 261842b7f7..fcbc7e8416 100644 --- a/src/pubkey.h +++ b/src/pubkey.h @@ -142,6 +142,9 @@ public: unsigned int len = ::ReadCompactSize(s); if (len <= SIZE) { s.read((char*)vch, len); + if (len != size()) { + Invalidate(); + } } else { // invalid pubkey, skip available data char dummy; @@ -154,13 +157,13 @@ public: //! Get the KeyID of this public key (hash of its serialization) CKeyID GetID() const { - return CKeyID(Hash160(vch, vch + size())); + return CKeyID(Hash160(MakeSpan(vch).first(size()))); } //! Get the 256-bit hash of this public key. uint256 GetHash() const { - return Hash(vch, vch + size()); + return Hash(MakeSpan(vch).first(size())); } /* diff --git a/src/qt/bitcoin.cpp b/src/qt/bitcoin.cpp index ad74ca3a02..f53fcc41f3 100644 --- a/src/qt/bitcoin.cpp +++ b/src/qt/bitcoin.cpp @@ -29,11 +29,12 @@ #include <interfaces/handler.h> #include <interfaces/node.h> +#include <node/context.h> #include <noui.h> -#include <ui_interface.h> #include <uint256.h> #include <util/system.h> #include <util/threadnames.h> +#include <util/translation.h> #include <validation.h> #include <memory> @@ -65,6 +66,24 @@ Q_DECLARE_METATYPE(CAmount) Q_DECLARE_METATYPE(SynchronizationState) Q_DECLARE_METATYPE(uint256) +static void RegisterMetaTypes() +{ + // Register meta types used for QMetaObject::invokeMethod and Qt::QueuedConnection + qRegisterMetaType<bool*>(); + qRegisterMetaType<SynchronizationState>(); + #ifdef ENABLE_WALLET + qRegisterMetaType<WalletModel*>(); + #endif + // Register typedefs (see http://qt-project.org/doc/qt-5/qmetatype.html#qRegisterMetaType) + // IMPORTANT: if CAmount is no longer a typedef use the normal variant above (see https://doc.qt.io/qt-5/qmetatype.html#qRegisterMetaType-1) + qRegisterMetaType<CAmount>("CAmount"); + qRegisterMetaType<size_t>("size_t"); + + qRegisterMetaType<std::function<void()>>("std::function<void()>"); + qRegisterMetaType<QMessageBox::Icon>("QMessageBox::Icon"); + qRegisterMetaType<interfaces::BlockAndHeaderTipInfo>("interfaces::BlockAndHeaderTipInfo"); +} + static QString GetLangTerritory() { QSettings settings; @@ -137,7 +156,7 @@ BitcoinCore::BitcoinCore(interfaces::Node& node) : void BitcoinCore::handleRunawayException(const std::exception *e) { PrintExceptionContinue(e, "Runaway exception"); - Q_EMIT runawayException(QString::fromStdString(m_node.getWarnings())); + Q_EMIT runawayException(QString::fromStdString(m_node.getWarnings().translated)); } void BitcoinCore::initialize() @@ -146,8 +165,9 @@ void BitcoinCore::initialize() { qDebug() << __func__ << ": Running initialization in thread"; util::ThreadRename("qt-init"); - bool rv = m_node.appInitMain(); - Q_EMIT initializeResult(rv); + interfaces::BlockAndHeaderTipInfo tip_info; + bool rv = m_node.appInitMain(&tip_info); + Q_EMIT initializeResult(rv, tip_info); } catch (const std::exception& e) { handleRunawayException(&e); } catch (...) { @@ -184,6 +204,7 @@ BitcoinApplication::BitcoinApplication(interfaces::Node& node): returnValue(0), platformStyle(nullptr) { + RegisterMetaTypes(); setQuitOnLastWindowClosed(false); } @@ -323,7 +344,7 @@ void BitcoinApplication::requestShutdown() Q_EMIT requestedShutdown(); } -void BitcoinApplication::initializeResult(bool success) +void BitcoinApplication::initializeResult(bool success, interfaces::BlockAndHeaderTipInfo tip_info) { qDebug() << __func__ << ": Initialization result: " << success; // Set exit result. @@ -333,7 +354,7 @@ void BitcoinApplication::initializeResult(bool success) // Log this only after AppInitMain finishes, as then logging setup is guaranteed complete qInfo() << "Platform customization:" << platformStyle->getName(); clientModel = new ClientModel(m_node, optionsModel); - window->setClientModel(clientModel); + window->setClientModel(clientModel, &tip_info); #ifdef ENABLE_WALLET if (WalletModel::isWalletEnabled()) { m_wallet_controller = new WalletController(*clientModel, platformStyle, this); @@ -381,7 +402,7 @@ void BitcoinApplication::shutdownResult() void BitcoinApplication::handleRunawayException(const QString &message) { - QMessageBox::critical(nullptr, "Runaway exception", BitcoinGUI::tr("A fatal error occurred. %1 can no longer continue safely and will quit.").arg(PACKAGE_NAME) + QString("\n\n") + message); + QMessageBox::critical(nullptr, "Runaway exception", BitcoinGUI::tr("A fatal error occurred. %1 can no longer continue safely and will quit.").arg(PACKAGE_NAME) + QString("<br><br>") + message); ::exit(EXIT_FAILURE); } @@ -393,14 +414,14 @@ WId BitcoinApplication::getMainWinId() const return window->winId(); } -static void SetupUIArgs() +static void SetupUIArgs(ArgsManager& argsman) { - gArgs.AddArg("-choosedatadir", strprintf("Choose data directory on startup (default: %u)", DEFAULT_CHOOSE_DATADIR), ArgsManager::ALLOW_ANY, OptionsCategory::GUI); - gArgs.AddArg("-lang=<lang>", "Set language, for example \"de_DE\" (default: system locale)", ArgsManager::ALLOW_ANY, OptionsCategory::GUI); - gArgs.AddArg("-min", "Start minimized", ArgsManager::ALLOW_ANY, OptionsCategory::GUI); - gArgs.AddArg("-resetguisettings", "Reset all settings changed in the GUI", ArgsManager::ALLOW_ANY, OptionsCategory::GUI); - gArgs.AddArg("-splash", strprintf("Show splash screen on startup (default: %u)", DEFAULT_SPLASHSCREEN), ArgsManager::ALLOW_ANY, OptionsCategory::GUI); - gArgs.AddArg("-uiplatform", strprintf("Select platform to customize UI for (one of windows, macosx, other; default: %s)", BitcoinGUI::DEFAULT_UIPLATFORM), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::GUI); + argsman.AddArg("-choosedatadir", strprintf("Choose data directory on startup (default: %u)", DEFAULT_CHOOSE_DATADIR), ArgsManager::ALLOW_ANY, OptionsCategory::GUI); + argsman.AddArg("-lang=<lang>", "Set language, for example \"de_DE\" (default: system locale)", ArgsManager::ALLOW_ANY, OptionsCategory::GUI); + argsman.AddArg("-min", "Start minimized", ArgsManager::ALLOW_ANY, OptionsCategory::GUI); + argsman.AddArg("-resetguisettings", "Reset all settings changed in the GUI", ArgsManager::ALLOW_ANY, OptionsCategory::GUI); + argsman.AddArg("-splash", strprintf("Show splash screen on startup (default: %u)", DEFAULT_SPLASHSCREEN), ArgsManager::ALLOW_ANY, OptionsCategory::GUI); + argsman.AddArg("-uiplatform", strprintf("Select platform to customize UI for (one of windows, macosx, other; default: %s)", BitcoinGUI::DEFAULT_UIPLATFORM), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::GUI); } int GuiMain(int argc, char* argv[]) @@ -412,7 +433,8 @@ int GuiMain(int argc, char* argv[]) SetupEnvironment(); util::ThreadSetInternalName("main"); - std::unique_ptr<interfaces::Node> node = interfaces::MakeNode(); + NodeContext node_context; + std::unique_ptr<interfaces::Node> node = interfaces::MakeNode(&node_context); // Subscribe to global signals from core std::unique_ptr<interfaces::Handler> handler_message_box = node->handleMessageBox(noui_ThreadSafeMessageBox); @@ -433,27 +455,13 @@ int GuiMain(int argc, char* argv[]) BitcoinApplication app(*node); - // Register meta types used for QMetaObject::invokeMethod and Qt::QueuedConnection - qRegisterMetaType<bool*>(); - qRegisterMetaType<SynchronizationState>(); -#ifdef ENABLE_WALLET - qRegisterMetaType<WalletModel*>(); -#endif - // Register typedefs (see http://qt-project.org/doc/qt-5/qmetatype.html#qRegisterMetaType) - // IMPORTANT: if CAmount is no longer a typedef use the normal variant above (see https://doc.qt.io/qt-5/qmetatype.html#qRegisterMetaType-1) - qRegisterMetaType<CAmount>("CAmount"); - qRegisterMetaType<size_t>("size_t"); - - qRegisterMetaType<std::function<void()>>("std::function<void()>"); - qRegisterMetaType<QMessageBox::Icon>("QMessageBox::Icon"); - /// 2. Parse command-line options. We do this after qt in order to show an error if there are problems parsing these // Command-line options take precedence: node->setupServerArgs(); - SetupUIArgs(); + SetupUIArgs(gArgs); std::string error; if (!node->parseParameters(argc, argv, error)) { - node->initError(strprintf("Error parsing command line arguments: %s\n", error)); + node->initError(strprintf(Untranslated("Error parsing command line arguments: %s\n"), error)); // Create a message box, because the gui has neither been created nor has subscribed to core signals QMessageBox::critical(nullptr, PACKAGE_NAME, // message can not be translated because translations have not been initialized @@ -494,13 +502,13 @@ int GuiMain(int argc, char* argv[]) /// 6. Determine availability of data directory and parse bitcoin.conf /// - Do not call GetDataDir(true) before this step finishes if (!CheckDataDirOption()) { - node->initError(strprintf("Specified data directory \"%s\" does not exist.\n", gArgs.GetArg("-datadir", ""))); + node->initError(strprintf(Untranslated("Specified data directory \"%s\" does not exist.\n"), gArgs.GetArg("-datadir", ""))); QMessageBox::critical(nullptr, PACKAGE_NAME, QObject::tr("Error: Specified data directory \"%1\" does not exist.").arg(QString::fromStdString(gArgs.GetArg("-datadir", "")))); return EXIT_FAILURE; } if (!node->readConfigFiles(error)) { - node->initError(strprintf("Error reading configuration file: %s\n", error)); + node->initError(strprintf(Untranslated("Error reading configuration file: %s\n"), error)); QMessageBox::critical(nullptr, PACKAGE_NAME, QObject::tr("Error: Cannot parse configuration file: %1.").arg(QString::fromStdString(error))); return EXIT_FAILURE; @@ -516,7 +524,7 @@ int GuiMain(int argc, char* argv[]) try { node->selectParams(gArgs.GetChainName()); } catch(std::exception &e) { - node->initError(strprintf("%s\n", e.what())); + node->initError(Untranslated(strprintf("%s\n", e.what()))); QMessageBox::critical(nullptr, PACKAGE_NAME, QObject::tr("Error: %1").arg(e.what())); return EXIT_FAILURE; } @@ -524,6 +532,11 @@ int GuiMain(int argc, char* argv[]) // Parse URIs on command line -- this can affect Params() PaymentServer::ipcParseCommandLine(*node, argc, argv); #endif + if (!node->initSettings(error)) { + node->initError(Untranslated(error)); + QMessageBox::critical(nullptr, PACKAGE_NAME, QObject::tr("Error initializing settings: %1").arg(QString::fromStdString(error))); + return EXIT_FAILURE; + } QScopedPointer<const NetworkStyle> networkStyle(NetworkStyle::instantiate(Params().NetworkIDString())); assert(!networkStyle.isNull()); @@ -552,6 +565,8 @@ int GuiMain(int argc, char* argv[]) /// 9. Main GUI initialization // Install global event filter that makes sure that long tooltips can be word-wrapped app.installEventFilter(new GUIUtil::ToolTipToRichTextFilter(TOOLTIP_WRAP_THRESHOLD, &app)); + // Install global event filter that makes sure that out-of-focus labels do not contain text cursor. + app.installEventFilter(new GUIUtil::LabelOutOfFocusEventFilter(&app)); #if defined(Q_OS_WIN) // Install global event filter for processing Windows session related Windows messages (WM_QUERYENDSESSION and WM_ENDSESSION) qApp->installNativeEventFilter(new WinShutdownMonitor()); @@ -594,10 +609,10 @@ int GuiMain(int argc, char* argv[]) } } catch (const std::exception& e) { PrintExceptionContinue(&e, "Runaway exception"); - app.handleRunawayException(QString::fromStdString(node->getWarnings())); + app.handleRunawayException(QString::fromStdString(node->getWarnings().translated)); } catch (...) { PrintExceptionContinue(nullptr, "Runaway exception"); - app.handleRunawayException(QString::fromStdString(node->getWarnings())); + app.handleRunawayException(QString::fromStdString(node->getWarnings().translated)); } return rv; } diff --git a/src/qt/bitcoin.h b/src/qt/bitcoin.h index 077a37fde5..20c6dfc047 100644 --- a/src/qt/bitcoin.h +++ b/src/qt/bitcoin.h @@ -12,6 +12,8 @@ #include <QApplication> #include <memory> +#include <interfaces/node.h> + class BitcoinGUI; class ClientModel; class NetworkStyle; @@ -21,10 +23,6 @@ class PlatformStyle; class WalletController; class WalletModel; -namespace interfaces { -class Handler; -class Node; -} // namespace interfaces /** Class encapsulating Bitcoin Core startup and shutdown. * Allows running startup and shutdown in a different thread from the UI thread. @@ -40,7 +38,7 @@ public Q_SLOTS: void shutdown(); Q_SIGNALS: - void initializeResult(bool success); + void initializeResult(bool success, interfaces::BlockAndHeaderTipInfo tip_info); void shutdownResult(); void runawayException(const QString &message); @@ -91,7 +89,7 @@ public: void setupPlatformStyle(); public Q_SLOTS: - void initializeResult(bool success); + void initializeResult(bool success, interfaces::BlockAndHeaderTipInfo tip_info); void shutdownResult(); /// Handle runaway exceptions. Shows a message box with the problem and quits the program. void handleRunawayException(const QString &message); diff --git a/src/qt/bitcoin.qrc b/src/qt/bitcoin.qrc index 037b23e4b2..7115459808 100644 --- a/src/qt/bitcoin.qrc +++ b/src/qt/bitcoin.qrc @@ -45,42 +45,42 @@ <file alias="network_disabled">res/icons/network_disabled.png</file> <file alias="proxy">res/icons/proxy.png</file> </qresource> - <qresource prefix="/movies"> - <file alias="spinner-000">res/movies/spinner-000.png</file> - <file alias="spinner-001">res/movies/spinner-001.png</file> - <file alias="spinner-002">res/movies/spinner-002.png</file> - <file alias="spinner-003">res/movies/spinner-003.png</file> - <file alias="spinner-004">res/movies/spinner-004.png</file> - <file alias="spinner-005">res/movies/spinner-005.png</file> - <file alias="spinner-006">res/movies/spinner-006.png</file> - <file alias="spinner-007">res/movies/spinner-007.png</file> - <file alias="spinner-008">res/movies/spinner-008.png</file> - <file alias="spinner-009">res/movies/spinner-009.png</file> - <file alias="spinner-010">res/movies/spinner-010.png</file> - <file alias="spinner-011">res/movies/spinner-011.png</file> - <file alias="spinner-012">res/movies/spinner-012.png</file> - <file alias="spinner-013">res/movies/spinner-013.png</file> - <file alias="spinner-014">res/movies/spinner-014.png</file> - <file alias="spinner-015">res/movies/spinner-015.png</file> - <file alias="spinner-016">res/movies/spinner-016.png</file> - <file alias="spinner-017">res/movies/spinner-017.png</file> - <file alias="spinner-018">res/movies/spinner-018.png</file> - <file alias="spinner-019">res/movies/spinner-019.png</file> - <file alias="spinner-020">res/movies/spinner-020.png</file> - <file alias="spinner-021">res/movies/spinner-021.png</file> - <file alias="spinner-022">res/movies/spinner-022.png</file> - <file alias="spinner-023">res/movies/spinner-023.png</file> - <file alias="spinner-024">res/movies/spinner-024.png</file> - <file alias="spinner-025">res/movies/spinner-025.png</file> - <file alias="spinner-026">res/movies/spinner-026.png</file> - <file alias="spinner-027">res/movies/spinner-027.png</file> - <file alias="spinner-028">res/movies/spinner-028.png</file> - <file alias="spinner-029">res/movies/spinner-029.png</file> - <file alias="spinner-030">res/movies/spinner-030.png</file> - <file alias="spinner-031">res/movies/spinner-031.png</file> - <file alias="spinner-032">res/movies/spinner-032.png</file> - <file alias="spinner-033">res/movies/spinner-033.png</file> - <file alias="spinner-034">res/movies/spinner-034.png</file> - <file alias="spinner-035">res/movies/spinner-035.png</file> + <qresource prefix="/animation"> + <file alias="spinner-000">res/animation/spinner-000.png</file> + <file alias="spinner-001">res/animation/spinner-001.png</file> + <file alias="spinner-002">res/animation/spinner-002.png</file> + <file alias="spinner-003">res/animation/spinner-003.png</file> + <file alias="spinner-004">res/animation/spinner-004.png</file> + <file alias="spinner-005">res/animation/spinner-005.png</file> + <file alias="spinner-006">res/animation/spinner-006.png</file> + <file alias="spinner-007">res/animation/spinner-007.png</file> + <file alias="spinner-008">res/animation/spinner-008.png</file> + <file alias="spinner-009">res/animation/spinner-009.png</file> + <file alias="spinner-010">res/animation/spinner-010.png</file> + <file alias="spinner-011">res/animation/spinner-011.png</file> + <file alias="spinner-012">res/animation/spinner-012.png</file> + <file alias="spinner-013">res/animation/spinner-013.png</file> + <file alias="spinner-014">res/animation/spinner-014.png</file> + <file alias="spinner-015">res/animation/spinner-015.png</file> + <file alias="spinner-016">res/animation/spinner-016.png</file> + <file alias="spinner-017">res/animation/spinner-017.png</file> + <file alias="spinner-018">res/animation/spinner-018.png</file> + <file alias="spinner-019">res/animation/spinner-019.png</file> + <file alias="spinner-020">res/animation/spinner-020.png</file> + <file alias="spinner-021">res/animation/spinner-021.png</file> + <file alias="spinner-022">res/animation/spinner-022.png</file> + <file alias="spinner-023">res/animation/spinner-023.png</file> + <file alias="spinner-024">res/animation/spinner-024.png</file> + <file alias="spinner-025">res/animation/spinner-025.png</file> + <file alias="spinner-026">res/animation/spinner-026.png</file> + <file alias="spinner-027">res/animation/spinner-027.png</file> + <file alias="spinner-028">res/animation/spinner-028.png</file> + <file alias="spinner-029">res/animation/spinner-029.png</file> + <file alias="spinner-030">res/animation/spinner-030.png</file> + <file alias="spinner-031">res/animation/spinner-031.png</file> + <file alias="spinner-032">res/animation/spinner-032.png</file> + <file alias="spinner-033">res/animation/spinner-033.png</file> + <file alias="spinner-034">res/animation/spinner-034.png</file> + <file alias="spinner-035">res/animation/spinner-035.png</file> </qresource> </RCC> diff --git a/src/qt/bitcoinamountfield.cpp b/src/qt/bitcoinamountfield.cpp index 4c57f1e352..a953c991bc 100644 --- a/src/qt/bitcoinamountfield.cpp +++ b/src/qt/bitcoinamountfield.cpp @@ -56,7 +56,7 @@ public: if (valid) { val = qBound(m_min_amount, val, m_max_amount); - input = BitcoinUnits::format(currentUnit, val, false, BitcoinUnits::separatorAlways); + input = BitcoinUnits::format(currentUnit, val, false, BitcoinUnits::SeparatorStyle::ALWAYS); lineEdit()->setText(input); } } @@ -68,7 +68,7 @@ public: void setValue(const CAmount& value) { - lineEdit()->setText(BitcoinUnits::format(currentUnit, value, false, BitcoinUnits::separatorAlways)); + lineEdit()->setText(BitcoinUnits::format(currentUnit, value, false, BitcoinUnits::SeparatorStyle::ALWAYS)); Q_EMIT valueChanged(); } @@ -102,7 +102,7 @@ public: CAmount val = value(&valid); currentUnit = unit; - lineEdit()->setPlaceholderText(BitcoinUnits::format(currentUnit, m_min_amount, false, BitcoinUnits::separatorAlways)); + lineEdit()->setPlaceholderText(BitcoinUnits::format(currentUnit, m_min_amount, false, BitcoinUnits::SeparatorStyle::ALWAYS)); if(valid) setValue(val); else @@ -122,7 +122,7 @@ public: const QFontMetrics fm(fontMetrics()); int h = lineEdit()->minimumSizeHint().height(); - int w = GUIUtil::TextWidth(fm, BitcoinUnits::format(BitcoinUnits::BTC, BitcoinUnits::maxMoney(), false, BitcoinUnits::separatorAlways)); + int w = GUIUtil::TextWidth(fm, BitcoinUnits::format(BitcoinUnits::BTC, BitcoinUnits::maxMoney(), false, BitcoinUnits::SeparatorStyle::ALWAYS)); w += 2; // cursor blinking space QStyleOptionSpinBox opt; diff --git a/src/qt/bitcoingui.cpp b/src/qt/bitcoingui.cpp index 6192013e5f..56adbf249a 100644 --- a/src/qt/bitcoingui.cpp +++ b/src/qt/bitcoingui.cpp @@ -34,7 +34,7 @@ #include <chainparams.h> #include <interfaces/handler.h> #include <interfaces/node.h> -#include <ui_interface.h> +#include <node/ui_interface.h> #include <util/system.h> #include <util/translation.h> #include <validation.h> @@ -112,6 +112,8 @@ BitcoinGUI::BitcoinGUI(interfaces::Node& node, const PlatformStyle *_platformSty Q_EMIT consoleShown(rpcConsole); } + modalOverlay = new ModalOverlay(enableWallet, this->centralWidget()); + // Accept D&D of URIs setAcceptDrops(true); @@ -201,7 +203,6 @@ BitcoinGUI::BitcoinGUI(interfaces::Node& node, const PlatformStyle *_platformSty openOptionsDialogWithTab(OptionsDialog::TAB_NETWORK); }); - modalOverlay = new ModalOverlay(enableWallet, this->centralWidget()); connect(labelBlocksIcon, &GUIUtil::ClickableLabel::clicked, this, &BitcoinGUI::showModalOverlay); connect(progressBar, &GUIUtil::ClickableProgressBar::clicked, this, &BitcoinGUI::showModalOverlay); #ifdef ENABLE_WALLET @@ -238,6 +239,7 @@ BitcoinGUI::~BitcoinGUI() void BitcoinGUI::createActions() { QActionGroup *tabGroup = new QActionGroup(this); + connect(modalOverlay, &ModalOverlay::triggered, tabGroup, &QActionGroup::setEnabled); overviewAction = new QAction(platformStyle->SingleColorIcon(":/icons/overview"), tr("&Overview"), this); overviewAction->setStatusTip(tr("Show general overview of wallet")); @@ -321,8 +323,10 @@ void BitcoinGUI::createActions() signMessageAction->setStatusTip(tr("Sign messages with your Bitcoin addresses to prove you own them")); verifyMessageAction = new QAction(tr("&Verify message..."), this); verifyMessageAction->setStatusTip(tr("Verify messages to ensure they were signed with specified Bitcoin addresses")); - m_load_psbt_action = new QAction(tr("Load PSBT..."), this); + m_load_psbt_action = new QAction(tr("&Load PSBT from file..."), this); m_load_psbt_action->setStatusTip(tr("Load Partially Signed Bitcoin Transaction")); + m_load_psbt_clipboard_action = new QAction(tr("Load PSBT from clipboard..."), this); + m_load_psbt_clipboard_action->setStatusTip(tr("Load Partially Signed Bitcoin Transaction from clipboard")); openRPCConsoleAction = new QAction(tr("Node window"), this); openRPCConsoleAction->setStatusTip(tr("Open node debugging and diagnostic console")); @@ -350,6 +354,9 @@ void BitcoinGUI::createActions() m_create_wallet_action->setEnabled(false); m_create_wallet_action->setStatusTip(tr("Create a new wallet")); + m_close_all_wallets_action = new QAction(tr("Close All Wallets..."), this); + m_close_all_wallets_action->setStatusTip(tr("Close all wallets")); + showHelpMessageAction = new QAction(tr("&Command-line options"), this); showHelpMessageAction->setMenuRole(QAction::NoRole); showHelpMessageAction->setStatusTip(tr("Show the %1 help message to get a list with possible Bitcoin command-line options").arg(PACKAGE_NAME)); @@ -378,6 +385,7 @@ void BitcoinGUI::createActions() connect(signMessageAction, &QAction::triggered, [this]{ showNormalIfMinimized(); }); connect(signMessageAction, &QAction::triggered, [this]{ gotoSignMessageTab(); }); connect(m_load_psbt_action, &QAction::triggered, [this]{ gotoLoadPSBT(); }); + connect(m_load_psbt_clipboard_action, &QAction::triggered, [this]{ gotoLoadPSBT(true); }); connect(verifyMessageAction, &QAction::triggered, [this]{ showNormalIfMinimized(); }); connect(verifyMessageAction, &QAction::triggered, [this]{ gotoVerifyMessageTab(); }); connect(usedSendingAddressesAction, &QAction::triggered, walletFrame, &WalletFrame::usedSendingAddresses); @@ -421,7 +429,9 @@ void BitcoinGUI::createActions() connect(activity, &CreateWalletActivity::finished, activity, &QObject::deleteLater); activity->create(); }); - + connect(m_close_all_wallets_action, &QAction::triggered, [this] { + m_wallet_controller->closeAllWallets(this); + }); connect(m_mask_values_action, &QAction::toggled, this, &BitcoinGUI::setPrivacy); } #endif // ENABLE_WALLET @@ -447,12 +457,14 @@ void BitcoinGUI::createMenuBar() file->addAction(m_create_wallet_action); file->addAction(m_open_wallet_action); file->addAction(m_close_wallet_action); + file->addAction(m_close_all_wallets_action); file->addSeparator(); file->addAction(openAction); file->addAction(backupWalletAction); file->addAction(signMessageAction); file->addAction(verifyMessageAction); file->addAction(m_load_psbt_action); + file->addAction(m_load_psbt_clipboard_action); file->addSeparator(); } file->addAction(quitAction); @@ -562,7 +574,7 @@ void BitcoinGUI::createToolBars() } } -void BitcoinGUI::setClientModel(ClientModel *_clientModel) +void BitcoinGUI::setClientModel(ClientModel *_clientModel, interfaces::BlockAndHeaderTipInfo* tip_info) { this->clientModel = _clientModel; if(_clientModel) @@ -576,8 +588,8 @@ void BitcoinGUI::setClientModel(ClientModel *_clientModel) connect(_clientModel, &ClientModel::numConnectionsChanged, this, &BitcoinGUI::setNumConnections); connect(_clientModel, &ClientModel::networkActiveChanged, this, &BitcoinGUI::setNetworkActive); - modalOverlay->setKnownBestHeight(_clientModel->getHeaderTipHeight(), QDateTime::fromTime_t(_clientModel->getHeaderTipTime())); - setNumBlocks(m_node.getNumBlocks(), QDateTime::fromTime_t(m_node.getLastBlockTime()), m_node.getVerificationProgress(), false, SynchronizationState::INIT_DOWNLOAD); + modalOverlay->setKnownBestHeight(tip_info->header_height, QDateTime::fromTime_t(tip_info->header_time)); + setNumBlocks(tip_info->block_height, QDateTime::fromTime_t(tip_info->block_time), tip_info->verification_progress, false, SynchronizationState::INIT_DOWNLOAD); connect(_clientModel, &ClientModel::numBlocksChanged, this, &BitcoinGUI::setNumBlocks); // Receive and report messages from client model @@ -588,7 +600,7 @@ void BitcoinGUI::setClientModel(ClientModel *_clientModel) // Show progress dialog connect(_clientModel, &ClientModel::showProgress, this, &BitcoinGUI::showProgress); - rpcConsole->setClientModel(_clientModel); + rpcConsole->setClientModel(_clientModel, tip_info->block_height, tip_info->block_time, tip_info->verification_progress); updateProxyIcon(); @@ -673,6 +685,7 @@ void BitcoinGUI::removeWallet(WalletModel* walletModel) m_wallet_selector->removeItem(index); if (m_wallet_selector->count() == 0) { setWalletActionsEnabled(false); + overviewAction->setChecked(true); } else if (m_wallet_selector->count() == 1) { m_wallet_selector_label_action->setVisible(false); m_wallet_selector_action->setVisible(false); @@ -727,6 +740,7 @@ void BitcoinGUI::setWalletActionsEnabled(bool enabled) usedReceivingAddressesAction->setEnabled(enabled); openAction->setEnabled(enabled); m_close_wallet_action->setEnabled(enabled); + m_close_all_wallets_action->setEnabled(enabled); } void BitcoinGUI::createTrayIcon() @@ -871,9 +885,9 @@ void BitcoinGUI::gotoVerifyMessageTab(QString addr) { if (walletFrame) walletFrame->gotoVerifyMessageTab(addr); } -void BitcoinGUI::gotoLoadPSBT() +void BitcoinGUI::gotoLoadPSBT(bool from_clipboard) { - if (walletFrame) walletFrame->gotoLoadPSBT(); + if (walletFrame) walletFrame->gotoLoadPSBT(from_clipboard); } #endif // ENABLE_WALLET @@ -1026,7 +1040,7 @@ void BitcoinGUI::setNumBlocks(int count, const QDateTime& blockDate, double nVer if(count != prevBlocks) { labelBlocksIcon->setPixmap(platformStyle->SingleColorIcon(QString( - ":/movies/spinner-%1").arg(spinnerFrame, 3, 10, QChar('0'))) + ":/animation/spinner-%1").arg(spinnerFrame, 3, 10, QChar('0'))) .pixmap(STATUSBAR_ICONSIZE, STATUSBAR_ICONSIZE)); spinnerFrame = (spinnerFrame + 1) % SPINNER_FRAMES; } @@ -1062,9 +1076,6 @@ void BitcoinGUI::message(const QString& title, QString message, unsigned int sty int nMBoxIcon = QMessageBox::Information; int nNotifyIcon = Notificator::Information; - bool prefix = !(style & CClientUIInterface::MSG_NOPREFIX); - style &= ~CClientUIInterface::MSG_NOPREFIX; - QString msgType; if (!title.isEmpty()) { msgType = title; @@ -1072,11 +1083,11 @@ void BitcoinGUI::message(const QString& title, QString message, unsigned int sty switch (style) { case CClientUIInterface::MSG_ERROR: msgType = tr("Error"); - if (prefix) message = tr("Error: %1").arg(message); + message = tr("Error: %1").arg(message); break; case CClientUIInterface::MSG_WARNING: msgType = tr("Warning"); - if (prefix) message = tr("Warning: %1").arg(message); + message = tr("Warning: %1").arg(message); break; case CClientUIInterface::MSG_INFORMATION: msgType = tr("Information"); diff --git a/src/qt/bitcoingui.h b/src/qt/bitcoingui.h index c0198dd168..4c55f28693 100644 --- a/src/qt/bitcoingui.h +++ b/src/qt/bitcoingui.h @@ -43,6 +43,7 @@ enum class SynchronizationState; namespace interfaces { class Handler; class Node; +struct BlockAndHeaderTipInfo; } QT_BEGIN_NAMESPACE @@ -75,7 +76,7 @@ public: /** Set the client model. The client model represents the part of the core that communicates with the P2P network, and is wallet-agnostic. */ - void setClientModel(ClientModel *clientModel); + void setClientModel(ClientModel *clientModel = nullptr, interfaces::BlockAndHeaderTipInfo* tip_info = nullptr); #ifdef ENABLE_WALLET void setWalletController(WalletController* wallet_controller); #endif @@ -139,6 +140,7 @@ private: QAction* signMessageAction = nullptr; QAction* verifyMessageAction = nullptr; QAction* m_load_psbt_action = nullptr; + QAction* m_load_psbt_clipboard_action = nullptr; QAction* aboutAction = nullptr; QAction* receiveCoinsAction = nullptr; QAction* receiveCoinsMenuAction = nullptr; @@ -155,6 +157,7 @@ private: QAction* m_open_wallet_action{nullptr}; QMenu* m_open_wallet_menu{nullptr}; QAction* m_close_wallet_action{nullptr}; + QAction* m_close_all_wallets_action{nullptr}; QAction* m_wallet_selector_label_action = nullptr; QAction* m_wallet_selector_action = nullptr; QAction* m_mask_values_action{nullptr}; @@ -277,8 +280,8 @@ public Q_SLOTS: void gotoSignMessageTab(QString addr = ""); /** Show Sign/Verify Message dialog and switch to verify message tab */ void gotoVerifyMessageTab(QString addr = ""); - /** Show load Partially Signed Bitcoin Transaction dialog */ - void gotoLoadPSBT(); + /** Load Partially Signed Bitcoin Transaction from file or clipboard */ + void gotoLoadPSBT(bool from_clipboard = false); /** Show open dialog */ void openClicked(); diff --git a/src/qt/bitcoinunits.cpp b/src/qt/bitcoinunits.cpp index 318a6dcbfd..fd55c547fc 100644 --- a/src/qt/bitcoinunits.cpp +++ b/src/qt/bitcoinunits.cpp @@ -114,7 +114,7 @@ QString BitcoinUnits::format(int unit, const CAmount& nIn, bool fPlus, Separator // confused with the decimal marker. QChar thin_sp(THIN_SP_CP); int q_size = quotient_str.size(); - if (separators == separatorAlways || (separators == separatorStandard && q_size > 4)) + if (separators == SeparatorStyle::ALWAYS || (separators == SeparatorStyle::STANDARD && q_size > 4)) for (int i = 3; i < q_size; i += 3) quotient_str.insert(q_size - i, thin_sp); diff --git a/src/qt/bitcoinunits.h b/src/qt/bitcoinunits.h index dac5484393..e22ba0a938 100644 --- a/src/qt/bitcoinunits.h +++ b/src/qt/bitcoinunits.h @@ -46,11 +46,11 @@ public: SAT }; - enum SeparatorStyle + enum class SeparatorStyle { - separatorNever, - separatorStandard, - separatorAlways + NEVER, + STANDARD, + ALWAYS }; //! @name Static API @@ -72,11 +72,11 @@ public: //! Number of decimals left static int decimals(int unit); //! Format as string - static QString format(int unit, const CAmount& amount, bool plussign = false, SeparatorStyle separators = separatorStandard, bool justify = false); + static QString format(int unit, const CAmount& amount, bool plussign = false, SeparatorStyle separators = SeparatorStyle::STANDARD, bool justify = false); //! Format as string (with unit) - static QString formatWithUnit(int unit, const CAmount& amount, bool plussign=false, SeparatorStyle separators=separatorStandard); + static QString formatWithUnit(int unit, const CAmount& amount, bool plussign=false, SeparatorStyle separators=SeparatorStyle::STANDARD); //! Format as HTML string (with unit) - static QString formatHtmlWithUnit(int unit, const CAmount& amount, bool plussign=false, SeparatorStyle separators=separatorStandard); + static QString formatHtmlWithUnit(int unit, const CAmount& amount, bool plussign=false, SeparatorStyle separators=SeparatorStyle::STANDARD); //! Format as string (with unit) of fixed length to preserve privacy, if it is set. static QString formatWithPrivacy(int unit, const CAmount& amount, SeparatorStyle separators, bool privacy); //! Parse string to coin amount diff --git a/src/qt/clientmodel.cpp b/src/qt/clientmodel.cpp index f15921c5bc..7822d4c5f3 100644 --- a/src/qt/clientmodel.cpp +++ b/src/qt/clientmodel.cpp @@ -116,9 +116,23 @@ int ClientModel::getNumBlocks() const uint256 ClientModel::getBestBlockHash() { + uint256 tip{WITH_LOCK(m_cached_tip_mutex, return m_cached_tip_blocks)}; + + if (!tip.IsNull()) { + return tip; + } + + // Lock order must be: first `cs_main`, then `m_cached_tip_mutex`. + // The following will lock `cs_main` (and release it), so we must not + // own `m_cached_tip_mutex` here. + tip = m_node.getBestBlockHash(); + LOCK(m_cached_tip_mutex); + // We checked that `m_cached_tip_blocks` is not null above, but then we + // released the mutex `m_cached_tip_mutex`, so it could have changed in the + // meantime. Thus, check again. if (m_cached_tip_blocks.IsNull()) { - m_cached_tip_blocks = m_node.getBestBlockHash(); + m_cached_tip_blocks = tip; } return m_cached_tip_blocks; } @@ -152,7 +166,7 @@ enum BlockSource ClientModel::getBlockSource() const QString ClientModel::getStatusBarWarnings() const { - return QString::fromStdString(m_node.getWarnings()); + return QString::fromStdString(m_node.getWarnings().translated); } OptionsModel *ClientModel::getOptionsModel() diff --git a/src/qt/coincontroldialog.cpp b/src/qt/coincontroldialog.cpp index db77c17df0..7c72858501 100644 --- a/src/qt/coincontroldialog.cpp +++ b/src/qt/coincontroldialog.cpp @@ -400,7 +400,6 @@ void CoinControlDialog::updateLabels(CCoinControl& m_coin_control, WalletModel * // nPayAmount CAmount nPayAmount = 0; bool fDust = false; - CMutableTransaction txDummy; for (const CAmount &amount : CoinControlDialog::payAmounts) { nPayAmount += amount; @@ -409,7 +408,6 @@ void CoinControlDialog::updateLabels(CCoinControl& m_coin_control, WalletModel * { // Assumes a p2pkh script size CTxOut txout(amount, CScript() << std::vector<unsigned char>(24, 0)); - txDummy.vout.push_back(txout); fDust |= IsDust(txout, model->node().getDustRelayFee()); } } @@ -458,7 +456,7 @@ void CoinControlDialog::updateLabels(CCoinControl& m_coin_control, WalletModel * { CPubKey pubkey; PKHash *pkhash = boost::get<PKHash>(&address); - if (pkhash && model->wallet().getPubKey(out.txout.scriptPubKey, CKeyID(*pkhash), pubkey)) + if (pkhash && model->wallet().getPubKey(out.txout.scriptPubKey, ToKeyID(*pkhash), pubkey)) { nBytesInputs += (pubkey.IsCompressed() ? 148 : 180); } diff --git a/src/qt/forms/debugwindow.ui b/src/qt/forms/debugwindow.ui index 8b70800838..93840b4169 100644 --- a/src/qt/forms/debugwindow.ui +++ b/src/qt/forms/debugwindow.ui @@ -1078,22 +1078,16 @@ <height>426</height> </rect> </property> - <property name="minimumSize"> - <size> - <width>300</width> - <height>0</height> - </size> - </property> <layout class="QGridLayout" name="gridLayout_2" columnstretch="0,1"> <item row="0" column="0"> <widget class="QLabel" name="label_30"> <property name="text"> - <string>Whitelisted</string> + <string>Permissions</string> </property> </widget> </item> <item row="0" column="1"> - <widget class="QLabel" name="peerWhitelisted"> + <widget class="QLabel" name="peerPermissions"> <property name="cursor"> <cursorShape>IBeamCursor</cursorShape> </property> @@ -1270,36 +1264,13 @@ </widget> </item> <item row="8" column="0"> - <widget class="QLabel" name="label_24"> - <property name="text"> - <string>Ban Score</string> - </property> - </widget> - </item> - <item row="8" column="1"> - <widget class="QLabel" name="peerBanScore"> - <property name="cursor"> - <cursorShape>IBeamCursor</cursorShape> - </property> - <property name="text"> - <string>N/A</string> - </property> - <property name="textFormat"> - <enum>Qt::PlainText</enum> - </property> - <property name="textInteractionFlags"> - <set>Qt::LinksAccessibleByMouse|Qt::TextSelectableByKeyboard|Qt::TextSelectableByMouse</set> - </property> - </widget> - </item> - <item row="9" column="0"> <widget class="QLabel" name="label_22"> <property name="text"> <string>Connection Time</string> </property> </widget> </item> - <item row="9" column="1"> + <item row="8" column="1"> <widget class="QLabel" name="peerConnTime"> <property name="cursor"> <cursorShape>IBeamCursor</cursorShape> @@ -1315,14 +1286,14 @@ </property> </widget> </item> - <item row="10" column="0"> + <item row="9" column="0"> <widget class="QLabel" name="label_15"> <property name="text"> <string>Last Send</string> </property> </widget> </item> - <item row="10" column="1"> + <item row="9" column="1"> <widget class="QLabel" name="peerLastSend"> <property name="cursor"> <cursorShape>IBeamCursor</cursorShape> @@ -1338,14 +1309,14 @@ </property> </widget> </item> - <item row="11" column="0"> + <item row="10" column="0"> <widget class="QLabel" name="label_19"> <property name="text"> <string>Last Receive</string> </property> </widget> </item> - <item row="11" column="1"> + <item row="10" column="1"> <widget class="QLabel" name="peerLastRecv"> <property name="cursor"> <cursorShape>IBeamCursor</cursorShape> @@ -1361,14 +1332,14 @@ </property> </widget> </item> - <item row="12" column="0"> + <item row="11" column="0"> <widget class="QLabel" name="label_18"> <property name="text"> <string>Sent</string> </property> </widget> </item> - <item row="12" column="1"> + <item row="11" column="1"> <widget class="QLabel" name="peerBytesSent"> <property name="cursor"> <cursorShape>IBeamCursor</cursorShape> @@ -1384,14 +1355,14 @@ </property> </widget> </item> - <item row="13" column="0"> + <item row="12" column="0"> <widget class="QLabel" name="label_20"> <property name="text"> <string>Received</string> </property> </widget> </item> - <item row="13" column="1"> + <item row="12" column="1"> <widget class="QLabel" name="peerBytesRecv"> <property name="cursor"> <cursorShape>IBeamCursor</cursorShape> @@ -1407,14 +1378,14 @@ </property> </widget> </item> - <item row="14" column="0"> + <item row="13" column="0"> <widget class="QLabel" name="label_26"> <property name="text"> <string>Ping Time</string> </property> </widget> </item> - <item row="14" column="1"> + <item row="13" column="1"> <widget class="QLabel" name="peerPingTime"> <property name="cursor"> <cursorShape>IBeamCursor</cursorShape> @@ -1430,7 +1401,7 @@ </property> </widget> </item> - <item row="15" column="0"> + <item row="14" column="0"> <widget class="QLabel" name="peerPingWaitLabel"> <property name="toolTip"> <string>The duration of a currently outstanding ping.</string> @@ -1440,7 +1411,7 @@ </property> </widget> </item> - <item row="15" column="1"> + <item row="14" column="1"> <widget class="QLabel" name="peerPingWait"> <property name="cursor"> <cursorShape>IBeamCursor</cursorShape> @@ -1456,14 +1427,14 @@ </property> </widget> </item> - <item row="16" column="0"> + <item row="15" column="0"> <widget class="QLabel" name="peerMinPingLabel"> <property name="text"> <string>Min Ping</string> </property> </widget> </item> - <item row="16" column="1"> + <item row="15" column="1"> <widget class="QLabel" name="peerMinPing"> <property name="cursor"> <cursorShape>IBeamCursor</cursorShape> @@ -1479,14 +1450,14 @@ </property> </widget> </item> - <item row="17" column="0"> + <item row="16" column="0"> <widget class="QLabel" name="label_timeoffset"> <property name="text"> <string>Time Offset</string> </property> </widget> </item> - <item row="17" column="1"> + <item row="16" column="1"> <widget class="QLabel" name="timeoffset"> <property name="cursor"> <cursorShape>IBeamCursor</cursorShape> @@ -1502,7 +1473,7 @@ </property> </widget> </item> - <item row="18" column="0"> + <item row="17" column="0"> <widget class="QLabel" name="peerMappedASLabel"> <property name="toolTip"> <string>The mapped Autonomous System used for diversifying peer selection.</string> @@ -1512,7 +1483,7 @@ </property> </widget> </item> - <item row="18" column="1"> + <item row="17" column="1"> <widget class="QLabel" name="peerMappedAS"> <property name="cursor"> <cursorShape>IBeamCursor</cursorShape> @@ -1528,7 +1499,7 @@ </property> </widget> </item> - <item row="19" column="0"> + <item row="18" column="0"> <spacer name="verticalSpacer_3"> <property name="orientation"> <enum>Qt::Vertical</enum> diff --git a/src/qt/forms/optionsdialog.ui b/src/qt/forms/optionsdialog.ui index fea759dee0..0016fb9739 100644 --- a/src/qt/forms/optionsdialog.ui +++ b/src/qt/forms/optionsdialog.ui @@ -459,10 +459,10 @@ <item> <widget class="QCheckBox" name="connectSocksTor"> <property name="toolTip"> - <string>Connect to the Bitcoin network through a separate SOCKS5 proxy for Tor hidden services.</string> + <string>Connect to the Bitcoin network through a separate SOCKS5 proxy for Tor onion services.</string> </property> <property name="text"> - <string>Use separate SOCKS&5 proxy to reach peers via Tor hidden services:</string> + <string>Use separate SOCKS&5 proxy to reach peers via Tor onion services:</string> </property> </widget> </item> diff --git a/src/qt/forms/psbtoperationsdialog.ui b/src/qt/forms/psbtoperationsdialog.ui new file mode 100644 index 0000000000..c2e2f5035b --- /dev/null +++ b/src/qt/forms/psbtoperationsdialog.ui @@ -0,0 +1,148 @@ +<?xml version="1.0" encoding="UTF-8"?> +<ui version="4.0"> + <class>PSBTOperationsDialog</class> + <widget class="QDialog" name="PSBTOperationsDialog"> + <property name="geometry"> + <rect> + <x>0</x> + <y>0</y> + <width>585</width> + <height>327</height> + </rect> + </property> + <property name="windowTitle"> + <string>Dialog</string> + </property> + <layout class="QVBoxLayout" name="verticalLayout"> + <property name="spacing"> + <number>12</number> + </property> + <property name="sizeConstraint"> + <enum>QLayout::SetDefaultConstraint</enum> + </property> + <property name="bottomMargin"> + <number>12</number> + </property> + <item> + <layout class="QVBoxLayout" name="mainDialogLayout"> + <property name="spacing"> + <number>5</number> + </property> + <property name="topMargin"> + <number>0</number> + </property> + <property name="bottomMargin"> + <number>0</number> + </property> + <item> + <widget class="QLabel" name="statusBar"> + <property name="font"> + <font> + <weight>75</weight> + <bold>true</bold> + </font> + </property> + <property name="autoFillBackground"> + <bool>false</bool> + </property> + <property name="styleSheet"> + <string notr="true"/> + </property> + <property name="text"> + <string/> + </property> + </widget> + </item> + <item> + <widget class="QTextEdit" name="transactionDescription"> + <property name="undoRedoEnabled"> + <bool>false</bool> + </property> + <property name="readOnly"> + <bool>true</bool> + </property> + </widget> + </item> + <item> + <layout class="QHBoxLayout" name="buttonRowLayout"> + <property name="spacing"> + <number>5</number> + </property> + <item> + <widget class="QPushButton" name="signTransactionButton"> + <property name="sizePolicy"> + <sizepolicy hsizetype="Minimum" vsizetype="Fixed"> + <horstretch>0</horstretch> + <verstretch>0</verstretch> + </sizepolicy> + </property> + <property name="font"> + <font> + <weight>50</weight> + <bold>false</bold> + </font> + </property> + <property name="text"> + <string>Sign Tx</string> + </property> + <property name="autoDefault"> + <bool>true</bool> + </property> + <property name="default"> + <bool>false</bool> + </property> + <property name="flat"> + <bool>false</bool> + </property> + </widget> + </item> + <item> + <widget class="QPushButton" name="broadcastTransactionButton"> + <property name="text"> + <string>Broadcast Tx</string> + </property> + </widget> + </item> + <item> + <spacer name="horizontalSpacer"> + <property name="orientation"> + <enum>Qt::Horizontal</enum> + </property> + <property name="sizeHint" stdset="0"> + <size> + <width>40</width> + <height>20</height> + </size> + </property> + </spacer> + </item> + <item> + <widget class="QPushButton" name="copyToClipboardButton"> + <property name="text"> + <string>Copy to Clipboard</string> + </property> + </widget> + </item> + <item> + <widget class="QPushButton" name="saveButton"> + <property name="text"> + <string>Save...</string> + </property> + </widget> + </item> + <item> + <widget class="QPushButton" name="closeButton"> + <property name="text"> + <string>Close</string> + </property> + </widget> + </item> + </layout> + </item> + </layout> + </item> + </layout> + </widget> + <resources/> + <connections/> +</ui> diff --git a/src/qt/guiutil.cpp b/src/qt/guiutil.cpp index ce44d4f3a5..7f439fa45e 100644 --- a/src/qt/guiutil.cpp +++ b/src/qt/guiutil.cpp @@ -188,7 +188,7 @@ QString formatBitcoinURI(const SendCoinsRecipient &info) if (info.amount) { - ret += QString("?amount=%1").arg(BitcoinUnits::format(BitcoinUnits::BTC, info.amount, false, BitcoinUnits::separatorNever)); + ret += QString("?amount=%1").arg(BitcoinUnits::format(BitcoinUnits::BTC, info.amount, false, BitcoinUnits::SeparatorStyle::NEVER)); paramCount++; } @@ -450,6 +450,28 @@ bool ToolTipToRichTextFilter::eventFilter(QObject *obj, QEvent *evt) return QObject::eventFilter(obj, evt); } +LabelOutOfFocusEventFilter::LabelOutOfFocusEventFilter(QObject* parent) + : QObject(parent) +{ +} + +bool LabelOutOfFocusEventFilter::eventFilter(QObject* watched, QEvent* event) +{ + if (event->type() == QEvent::FocusOut) { + auto focus_out = static_cast<QFocusEvent*>(event); + if (focus_out->reason() != Qt::PopupFocusReason) { + auto label = qobject_cast<QLabel*>(watched); + if (label) { + auto flags = label->textInteractionFlags(); + label->setTextInteractionFlags(Qt::NoTextInteraction); + label->setTextInteractionFlags(flags); + } + } + } + + return QObject::eventFilter(watched, event); +} + void TableViewLastColumnResizingFixer::connectViewHeadersSignals() { connect(tableView->horizontalHeader(), &QHeaderView::sectionResized, this, &TableViewLastColumnResizingFixer::on_sectionResized); diff --git a/src/qt/guiutil.h b/src/qt/guiutil.h index 8741d90102..2bd94b5eb3 100644 --- a/src/qt/guiutil.h +++ b/src/qt/guiutil.h @@ -162,6 +162,21 @@ namespace GUIUtil }; /** + * Qt event filter that intercepts QEvent::FocusOut events for QLabel objects, and + * resets their `textInteractionFlags' property to get rid of the visible cursor. + * + * This is a temporary fix of QTBUG-59514. + */ + class LabelOutOfFocusEventFilter : public QObject + { + Q_OBJECT + + public: + explicit LabelOutOfFocusEventFilter(QObject* parent); + bool eventFilter(QObject* watched, QEvent* event) override; + }; + + /** * Makes a QTableView last column feel as if it was being resized from its left border. * Also makes sure the column widths are never larger than the table's viewport. * In Qt, all columns are resizable from the right, but it's not intuitive resizing the last column from the right. diff --git a/src/qt/modaloverlay.cpp b/src/qt/modaloverlay.cpp index 0ba1beaf3e..8070aa627c 100644 --- a/src/qt/modaloverlay.cpp +++ b/src/qt/modaloverlay.cpp @@ -171,6 +171,8 @@ void ModalOverlay::showHide(bool hide, bool userRequested) if ( (layerIsVisible && !hide) || (!layerIsVisible && hide) || (!hide && userClosed && !userRequested)) return; + Q_EMIT triggered(hide); + if (!isVisible() && !hide) setVisible(true); diff --git a/src/qt/modaloverlay.h b/src/qt/modaloverlay.h index 1d84046d3d..7b07777641 100644 --- a/src/qt/modaloverlay.h +++ b/src/qt/modaloverlay.h @@ -25,16 +25,20 @@ public: explicit ModalOverlay(bool enable_wallet, QWidget *parent); ~ModalOverlay(); -public Q_SLOTS: void tipUpdate(int count, const QDateTime& blockDate, double nVerificationProgress); void setKnownBestHeight(int count, const QDateTime& blockDate); - void toggleVisibility(); // will show or hide the modal layer void showHide(bool hide = false, bool userRequested = false); - void closeClicked(); bool isLayerVisible() const { return layerIsVisible; } +public Q_SLOTS: + void toggleVisibility(); + void closeClicked(); + +Q_SIGNALS: + void triggered(bool hidden); + protected: bool eventFilter(QObject * obj, QEvent * ev) override; bool event(QEvent* ev) override; diff --git a/src/qt/optionsmodel.h b/src/qt/optionsmodel.h index 6ca5ac9d75..14fdf9046e 100644 --- a/src/qt/optionsmodel.h +++ b/src/qt/optionsmodel.h @@ -6,6 +6,7 @@ #define BITCOIN_QT_OPTIONSMODEL_H #include <amount.h> +#include <cstdint> #include <qt/guiconstants.h> #include <QAbstractListModel> @@ -15,7 +16,7 @@ class Node; } extern const char *DEFAULT_GUI_PROXY_HOST; -static constexpr unsigned short DEFAULT_GUI_PROXY_PORT = 9050; +static constexpr uint16_t DEFAULT_GUI_PROXY_PORT = 9050; /** * Convert configured prune target MiB to displayed GB. Round up to avoid underestimating max disk usage. diff --git a/src/qt/overviewpage.cpp b/src/qt/overviewpage.cpp index 0af70f2735..b536567c8b 100644 --- a/src/qt/overviewpage.cpp +++ b/src/qt/overviewpage.cpp @@ -88,7 +88,7 @@ public: foreground = option.palette.color(QPalette::Text); } painter->setPen(foreground); - QString amountText = BitcoinUnits::formatWithUnit(unit, amount, true, BitcoinUnits::separatorAlways); + QString amountText = BitcoinUnits::formatWithUnit(unit, amount, true, BitcoinUnits::SeparatorStyle::ALWAYS); if(!confirmed) { amountText = QString("[") + amountText + QString("]"); @@ -180,25 +180,25 @@ void OverviewPage::setBalance(const interfaces::WalletBalances& balances) m_balances = balances; if (walletModel->wallet().isLegacy()) { if (walletModel->wallet().privateKeysDisabled()) { - ui->labelBalance->setText(BitcoinUnits::formatWithPrivacy(unit, balances.watch_only_balance, BitcoinUnits::separatorAlways, m_privacy)); - ui->labelUnconfirmed->setText(BitcoinUnits::formatWithPrivacy(unit, balances.unconfirmed_watch_only_balance, BitcoinUnits::separatorAlways, m_privacy)); - ui->labelImmature->setText(BitcoinUnits::formatWithPrivacy(unit, balances.immature_watch_only_balance, BitcoinUnits::separatorAlways, m_privacy)); - ui->labelTotal->setText(BitcoinUnits::formatWithPrivacy(unit, balances.watch_only_balance + balances.unconfirmed_watch_only_balance + balances.immature_watch_only_balance, BitcoinUnits::separatorAlways, m_privacy)); + ui->labelBalance->setText(BitcoinUnits::formatWithPrivacy(unit, balances.watch_only_balance, BitcoinUnits::SeparatorStyle::ALWAYS, m_privacy)); + ui->labelUnconfirmed->setText(BitcoinUnits::formatWithPrivacy(unit, balances.unconfirmed_watch_only_balance, BitcoinUnits::SeparatorStyle::ALWAYS, m_privacy)); + ui->labelImmature->setText(BitcoinUnits::formatWithPrivacy(unit, balances.immature_watch_only_balance, BitcoinUnits::SeparatorStyle::ALWAYS, m_privacy)); + ui->labelTotal->setText(BitcoinUnits::formatWithPrivacy(unit, balances.watch_only_balance + balances.unconfirmed_watch_only_balance + balances.immature_watch_only_balance, BitcoinUnits::SeparatorStyle::ALWAYS, m_privacy)); } else { - ui->labelBalance->setText(BitcoinUnits::formatWithPrivacy(unit, balances.balance, BitcoinUnits::separatorAlways, m_privacy)); - ui->labelUnconfirmed->setText(BitcoinUnits::formatWithPrivacy(unit, balances.unconfirmed_balance, BitcoinUnits::separatorAlways, m_privacy)); - ui->labelImmature->setText(BitcoinUnits::formatWithPrivacy(unit, balances.immature_balance, BitcoinUnits::separatorAlways, m_privacy)); - ui->labelTotal->setText(BitcoinUnits::formatWithPrivacy(unit, balances.balance + balances.unconfirmed_balance + balances.immature_balance, BitcoinUnits::separatorAlways, m_privacy)); - ui->labelWatchAvailable->setText(BitcoinUnits::formatWithPrivacy(unit, balances.watch_only_balance, BitcoinUnits::separatorAlways, m_privacy)); - ui->labelWatchPending->setText(BitcoinUnits::formatWithPrivacy(unit, balances.unconfirmed_watch_only_balance, BitcoinUnits::separatorAlways, m_privacy)); - ui->labelWatchImmature->setText(BitcoinUnits::formatWithPrivacy(unit, balances.immature_watch_only_balance, BitcoinUnits::separatorAlways, m_privacy)); - ui->labelWatchTotal->setText(BitcoinUnits::formatWithPrivacy(unit, balances.watch_only_balance + balances.unconfirmed_watch_only_balance + balances.immature_watch_only_balance, BitcoinUnits::separatorAlways, m_privacy)); + ui->labelBalance->setText(BitcoinUnits::formatWithPrivacy(unit, balances.balance, BitcoinUnits::SeparatorStyle::ALWAYS, m_privacy)); + ui->labelUnconfirmed->setText(BitcoinUnits::formatWithPrivacy(unit, balances.unconfirmed_balance, BitcoinUnits::SeparatorStyle::ALWAYS, m_privacy)); + ui->labelImmature->setText(BitcoinUnits::formatWithPrivacy(unit, balances.immature_balance, BitcoinUnits::SeparatorStyle::ALWAYS, m_privacy)); + ui->labelTotal->setText(BitcoinUnits::formatWithPrivacy(unit, balances.balance + balances.unconfirmed_balance + balances.immature_balance, BitcoinUnits::SeparatorStyle::ALWAYS, m_privacy)); + ui->labelWatchAvailable->setText(BitcoinUnits::formatWithPrivacy(unit, balances.watch_only_balance, BitcoinUnits::SeparatorStyle::ALWAYS, m_privacy)); + ui->labelWatchPending->setText(BitcoinUnits::formatWithPrivacy(unit, balances.unconfirmed_watch_only_balance, BitcoinUnits::SeparatorStyle::ALWAYS, m_privacy)); + ui->labelWatchImmature->setText(BitcoinUnits::formatWithPrivacy(unit, balances.immature_watch_only_balance, BitcoinUnits::SeparatorStyle::ALWAYS, m_privacy)); + ui->labelWatchTotal->setText(BitcoinUnits::formatWithPrivacy(unit, balances.watch_only_balance + balances.unconfirmed_watch_only_balance + balances.immature_watch_only_balance, BitcoinUnits::SeparatorStyle::ALWAYS, m_privacy)); } } else { - ui->labelBalance->setText(BitcoinUnits::formatWithPrivacy(unit, balances.balance, BitcoinUnits::separatorAlways, m_privacy)); - ui->labelUnconfirmed->setText(BitcoinUnits::formatWithPrivacy(unit, balances.unconfirmed_balance, BitcoinUnits::separatorAlways, m_privacy)); - ui->labelImmature->setText(BitcoinUnits::formatWithPrivacy(unit, balances.immature_balance, BitcoinUnits::separatorAlways, m_privacy)); - ui->labelTotal->setText(BitcoinUnits::formatWithPrivacy(unit, balances.balance + balances.unconfirmed_balance + balances.immature_balance, BitcoinUnits::separatorAlways, m_privacy)); + ui->labelBalance->setText(BitcoinUnits::formatWithPrivacy(unit, balances.balance, BitcoinUnits::SeparatorStyle::ALWAYS, m_privacy)); + ui->labelUnconfirmed->setText(BitcoinUnits::formatWithPrivacy(unit, balances.unconfirmed_balance, BitcoinUnits::SeparatorStyle::ALWAYS, m_privacy)); + ui->labelImmature->setText(BitcoinUnits::formatWithPrivacy(unit, balances.immature_balance, BitcoinUnits::SeparatorStyle::ALWAYS, m_privacy)); + ui->labelTotal->setText(BitcoinUnits::formatWithPrivacy(unit, balances.balance + balances.unconfirmed_balance + balances.immature_balance, BitcoinUnits::SeparatorStyle::ALWAYS, m_privacy)); } // only show immature (newly mined) balance if it's non-zero, so as not to complicate things // for the non-mining users diff --git a/src/qt/paymentserver.cpp b/src/qt/paymentserver.cpp index beca78a021..a1da85bda7 100644 --- a/src/qt/paymentserver.cpp +++ b/src/qt/paymentserver.cpp @@ -14,9 +14,9 @@ #include <chainparams.h> #include <interfaces/node.h> -#include <policy/policy.h> #include <key_io.h> -#include <ui_interface.h> +#include <node/ui_interface.h> +#include <policy/policy.h> #include <util/system.h> #include <wallet/wallet.h> diff --git a/src/qt/psbtoperationsdialog.cpp b/src/qt/psbtoperationsdialog.cpp new file mode 100644 index 0000000000..58167d4bb4 --- /dev/null +++ b/src/qt/psbtoperationsdialog.cpp @@ -0,0 +1,268 @@ +// Copyright (c) 2011-2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include <qt/psbtoperationsdialog.h> + +#include <core_io.h> +#include <interfaces/node.h> +#include <key_io.h> +#include <node/psbt.h> +#include <policy/policy.h> +#include <qt/bitcoinunits.h> +#include <qt/forms/ui_psbtoperationsdialog.h> +#include <qt/guiutil.h> +#include <qt/optionsmodel.h> +#include <util/strencodings.h> + +#include <iostream> + + +PSBTOperationsDialog::PSBTOperationsDialog( + QWidget* parent, WalletModel* wallet_model, ClientModel* client_model) : QDialog(parent), + m_ui(new Ui::PSBTOperationsDialog), + m_wallet_model(wallet_model), + m_client_model(client_model) +{ + m_ui->setupUi(this); + setWindowTitle("PSBT Operations"); + + connect(m_ui->signTransactionButton, &QPushButton::clicked, this, &PSBTOperationsDialog::signTransaction); + connect(m_ui->broadcastTransactionButton, &QPushButton::clicked, this, &PSBTOperationsDialog::broadcastTransaction); + connect(m_ui->copyToClipboardButton, &QPushButton::clicked, this, &PSBTOperationsDialog::copyToClipboard); + connect(m_ui->saveButton, &QPushButton::clicked, this, &PSBTOperationsDialog::saveTransaction); + + connect(m_ui->closeButton, &QPushButton::clicked, this, &PSBTOperationsDialog::close); + + m_ui->signTransactionButton->setEnabled(false); + m_ui->broadcastTransactionButton->setEnabled(false); +} + +PSBTOperationsDialog::~PSBTOperationsDialog() +{ + delete m_ui; +} + +void PSBTOperationsDialog::openWithPSBT(PartiallySignedTransaction psbtx) +{ + m_transaction_data = psbtx; + + bool complete; + size_t n_could_sign; + FinalizePSBT(psbtx); // Make sure all existing signatures are fully combined before checking for completeness. + TransactionError err = m_wallet_model->wallet().fillPSBT(SIGHASH_ALL, false /* sign */, true /* bip32derivs */, m_transaction_data, complete, &n_could_sign); + if (err != TransactionError::OK) { + showStatus(tr("Failed to load transaction: %1") + .arg(QString::fromStdString(TransactionErrorString(err).translated)), StatusLevel::ERR); + return; + } + + m_ui->broadcastTransactionButton->setEnabled(complete); + m_ui->signTransactionButton->setEnabled(!complete && !m_wallet_model->wallet().privateKeysDisabled() && n_could_sign > 0); + + updateTransactionDisplay(); +} + +void PSBTOperationsDialog::signTransaction() +{ + bool complete; + size_t n_signed; + TransactionError err = m_wallet_model->wallet().fillPSBT(SIGHASH_ALL, true /* sign */, true /* bip32derivs */, m_transaction_data, complete, &n_signed); + + if (err != TransactionError::OK) { + showStatus(tr("Failed to sign transaction: %1") + .arg(QString::fromStdString(TransactionErrorString(err).translated)), StatusLevel::ERR); + return; + } + + updateTransactionDisplay(); + + if (!complete && n_signed < 1) { + showStatus(tr("Could not sign any more inputs."), StatusLevel::WARN); + } else if (!complete) { + showStatus(tr("Signed %1 inputs, but more signatures are still required.").arg(n_signed), + StatusLevel::INFO); + } else { + showStatus(tr("Signed transaction successfully. Transaction is ready to broadcast."), + StatusLevel::INFO); + m_ui->broadcastTransactionButton->setEnabled(true); + } +} + +void PSBTOperationsDialog::broadcastTransaction() +{ + CMutableTransaction mtx; + if (!FinalizeAndExtractPSBT(m_transaction_data, mtx)) { + // This is never expected to fail unless we were given a malformed PSBT + // (e.g. with an invalid signature.) + showStatus(tr("Unknown error processing transaction."), StatusLevel::ERR); + return; + } + + CTransactionRef tx = MakeTransactionRef(mtx); + std::string err_string; + TransactionError error = BroadcastTransaction( + *m_client_model->node().context(), tx, err_string, DEFAULT_MAX_RAW_TX_FEE_RATE.GetFeePerK(), /* relay */ true, /* await_callback */ false); + + if (error == TransactionError::OK) { + showStatus(tr("Transaction broadcast successfully! Transaction ID: %1") + .arg(QString::fromStdString(tx->GetHash().GetHex())), StatusLevel::INFO); + } else { + showStatus(tr("Transaction broadcast failed: %1") + .arg(QString::fromStdString(TransactionErrorString(error).translated)), StatusLevel::ERR); + } +} + +void PSBTOperationsDialog::copyToClipboard() { + CDataStream ssTx(SER_NETWORK, PROTOCOL_VERSION); + ssTx << m_transaction_data; + GUIUtil::setClipboard(EncodeBase64(ssTx.str()).c_str()); + showStatus(tr("PSBT copied to clipboard."), StatusLevel::INFO); +} + +void PSBTOperationsDialog::saveTransaction() { + CDataStream ssTx(SER_NETWORK, PROTOCOL_VERSION); + ssTx << m_transaction_data; + + QString selected_filter; + QString filename_suggestion = ""; + bool first = true; + for (const CTxOut& out : m_transaction_data.tx->vout) { + if (!first) { + filename_suggestion.append("-"); + } + CTxDestination address; + ExtractDestination(out.scriptPubKey, address); + QString amount = BitcoinUnits::format(m_wallet_model->getOptionsModel()->getDisplayUnit(), out.nValue); + QString address_str = QString::fromStdString(EncodeDestination(address)); + filename_suggestion.append(address_str + "-" + amount); + first = false; + } + filename_suggestion.append(".psbt"); + QString filename = GUIUtil::getSaveFileName(this, + tr("Save Transaction Data"), filename_suggestion, + tr("Partially Signed Transaction (Binary) (*.psbt)"), &selected_filter); + if (filename.isEmpty()) { + return; + } + std::ofstream out(filename.toLocal8Bit().data()); + out << ssTx.str(); + out.close(); + showStatus(tr("PSBT saved to disk."), StatusLevel::INFO); +} + +void PSBTOperationsDialog::updateTransactionDisplay() { + m_ui->transactionDescription->setText(QString::fromStdString(renderTransaction(m_transaction_data))); + showTransactionStatus(m_transaction_data); +} + +std::string PSBTOperationsDialog::renderTransaction(const PartiallySignedTransaction &psbtx) +{ + QString tx_description = ""; + CAmount totalAmount = 0; + for (const CTxOut& out : psbtx.tx->vout) { + CTxDestination address; + ExtractDestination(out.scriptPubKey, address); + totalAmount += out.nValue; + tx_description.append(tr(" * Sends %1 to %2") + .arg(BitcoinUnits::formatWithUnit(BitcoinUnits::BTC, out.nValue)) + .arg(QString::fromStdString(EncodeDestination(address)))); + tx_description.append("<br>"); + } + + PSBTAnalysis analysis = AnalyzePSBT(psbtx); + tx_description.append(" * "); + if (!*analysis.fee) { + // This happens if the transaction is missing input UTXO information. + tx_description.append(tr("Unable to calculate transaction fee or total transaction amount.")); + } else { + tx_description.append(tr("Pays transaction fee: ")); + tx_description.append(BitcoinUnits::formatWithUnit(BitcoinUnits::BTC, *analysis.fee)); + + // add total amount in all subdivision units + tx_description.append("<hr />"); + QStringList alternativeUnits; + for (const BitcoinUnits::Unit u : BitcoinUnits::availableUnits()) + { + if(u != m_client_model->getOptionsModel()->getDisplayUnit()) { + alternativeUnits.append(BitcoinUnits::formatHtmlWithUnit(u, totalAmount)); + } + } + tx_description.append(QString("<b>%1</b>: <b>%2</b>").arg(tr("Total Amount")) + .arg(BitcoinUnits::formatHtmlWithUnit(m_client_model->getOptionsModel()->getDisplayUnit(), totalAmount))); + tx_description.append(QString("<br /><span style='font-size:10pt; font-weight:normal;'>(=%1)</span>") + .arg(alternativeUnits.join(" " + tr("or") + " "))); + } + + size_t num_unsigned = CountPSBTUnsignedInputs(psbtx); + if (num_unsigned > 0) { + tx_description.append("<br><br>"); + tx_description.append(tr("Transaction has %1 unsigned inputs.").arg(QString::number(num_unsigned))); + } + + return tx_description.toStdString(); +} + +void PSBTOperationsDialog::showStatus(const QString &msg, StatusLevel level) { + m_ui->statusBar->setText(msg); + switch (level) { + case StatusLevel::INFO: { + m_ui->statusBar->setStyleSheet("QLabel { background-color : lightgreen }"); + break; + } + case StatusLevel::WARN: { + m_ui->statusBar->setStyleSheet("QLabel { background-color : orange }"); + break; + } + case StatusLevel::ERR: { + m_ui->statusBar->setStyleSheet("QLabel { background-color : red }"); + break; + } + } + m_ui->statusBar->show(); +} + +size_t PSBTOperationsDialog::couldSignInputs(const PartiallySignedTransaction &psbtx) { + size_t n_signed; + bool complete; + TransactionError err = m_wallet_model->wallet().fillPSBT(SIGHASH_ALL, false /* sign */, false /* bip32derivs */, m_transaction_data, complete, &n_signed); + + if (err != TransactionError::OK) { + return 0; + } + return n_signed; +} + +void PSBTOperationsDialog::showTransactionStatus(const PartiallySignedTransaction &psbtx) { + PSBTAnalysis analysis = AnalyzePSBT(psbtx); + size_t n_could_sign = couldSignInputs(psbtx); + + switch (analysis.next) { + case PSBTRole::UPDATER: { + showStatus(tr("Transaction is missing some information about inputs."), StatusLevel::WARN); + break; + } + case PSBTRole::SIGNER: { + QString need_sig_text = tr("Transaction still needs signature(s)."); + StatusLevel level = StatusLevel::INFO; + if (m_wallet_model->wallet().privateKeysDisabled()) { + need_sig_text += " " + tr("(But this wallet cannot sign transactions.)"); + level = StatusLevel::WARN; + } else if (n_could_sign < 1) { + need_sig_text += " " + tr("(But this wallet does not have the right keys.)"); // XXX wording + level = StatusLevel::WARN; + } + showStatus(need_sig_text, level); + break; + } + case PSBTRole::FINALIZER: + case PSBTRole::EXTRACTOR: { + showStatus(tr("Transaction is fully signed and ready for broadcast."), StatusLevel::INFO); + break; + } + default: { + showStatus(tr("Transaction status is unknown."), StatusLevel::ERR); + break; + } + } +} diff --git a/src/qt/psbtoperationsdialog.h b/src/qt/psbtoperationsdialog.h new file mode 100644 index 0000000000..f37bdbe39a --- /dev/null +++ b/src/qt/psbtoperationsdialog.h @@ -0,0 +1,54 @@ +// Copyright (c) 2011-2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#ifndef BITCOIN_QT_PSBTOPERATIONSDIALOG_H +#define BITCOIN_QT_PSBTOPERATIONSDIALOG_H + +#include <QDialog> + +#include <psbt.h> +#include <qt/clientmodel.h> +#include <qt/walletmodel.h> + +namespace Ui { +class PSBTOperationsDialog; +} + +/** Dialog showing transaction details. */ +class PSBTOperationsDialog : public QDialog +{ + Q_OBJECT + +public: + explicit PSBTOperationsDialog(QWidget* parent, WalletModel* walletModel, ClientModel* clientModel); + ~PSBTOperationsDialog(); + + void openWithPSBT(PartiallySignedTransaction psbtx); + +public Q_SLOTS: + void signTransaction(); + void broadcastTransaction(); + void copyToClipboard(); + void saveTransaction(); + +private: + Ui::PSBTOperationsDialog* m_ui; + PartiallySignedTransaction m_transaction_data; + WalletModel* m_wallet_model; + ClientModel* m_client_model; + + enum class StatusLevel { + INFO, + WARN, + ERR + }; + + size_t couldSignInputs(const PartiallySignedTransaction &psbtx); + void updateTransactionDisplay(); + std::string renderTransaction(const PartiallySignedTransaction &psbtx); + void showStatus(const QString &msg, StatusLevel level); + void showTransactionStatus(const PartiallySignedTransaction &psbtx); +}; + +#endif // BITCOIN_QT_PSBTOPERATIONSDIALOG_H diff --git a/src/qt/recentrequeststablemodel.cpp b/src/qt/recentrequeststablemodel.cpp index 7419297a96..3e20368a36 100644 --- a/src/qt/recentrequeststablemodel.cpp +++ b/src/qt/recentrequeststablemodel.cpp @@ -82,7 +82,7 @@ QVariant RecentRequestsTableModel::data(const QModelIndex &index, int role) cons if (rec->recipient.amount == 0 && role == Qt::DisplayRole) return tr("(no amount requested)"); else if (role == Qt::EditRole) - return BitcoinUnits::format(walletModel->getOptionsModel()->getDisplayUnit(), rec->recipient.amount, false, BitcoinUnits::separatorNever); + return BitcoinUnits::format(walletModel->getOptionsModel()->getDisplayUnit(), rec->recipient.amount, false, BitcoinUnits::SeparatorStyle::NEVER); else return BitcoinUnits::format(walletModel->getOptionsModel()->getDisplayUnit(), rec->recipient.amount); } diff --git a/src/qt/res/movies/makespinner.sh b/src/qt/res/animation/makespinner.sh index 4fa8dadf86..4fa8dadf86 100755 --- a/src/qt/res/movies/makespinner.sh +++ b/src/qt/res/animation/makespinner.sh diff --git a/src/qt/res/movies/spinner-000.png b/src/qt/res/animation/spinner-000.png Binary files differindex 0dc48d0d8c..0dc48d0d8c 100644 --- a/src/qt/res/movies/spinner-000.png +++ b/src/qt/res/animation/spinner-000.png diff --git a/src/qt/res/movies/spinner-001.png b/src/qt/res/animation/spinner-001.png Binary files differindex d167f20541..d167f20541 100644 --- a/src/qt/res/movies/spinner-001.png +++ b/src/qt/res/animation/spinner-001.png diff --git a/src/qt/res/movies/spinner-002.png b/src/qt/res/animation/spinner-002.png Binary files differindex 4a1f1f8e56..4a1f1f8e56 100644 --- a/src/qt/res/movies/spinner-002.png +++ b/src/qt/res/animation/spinner-002.png diff --git a/src/qt/res/movies/spinner-003.png b/src/qt/res/animation/spinner-003.png Binary files differindex fb1c2cd4ad..fb1c2cd4ad 100644 --- a/src/qt/res/movies/spinner-003.png +++ b/src/qt/res/animation/spinner-003.png diff --git a/src/qt/res/movies/spinner-004.png b/src/qt/res/animation/spinner-004.png Binary files differindex 4df2132344..4df2132344 100644 --- a/src/qt/res/movies/spinner-004.png +++ b/src/qt/res/animation/spinner-004.png diff --git a/src/qt/res/movies/spinner-005.png b/src/qt/res/animation/spinner-005.png Binary files differindex 5d6f41e0dc..5d6f41e0dc 100644 --- a/src/qt/res/movies/spinner-005.png +++ b/src/qt/res/animation/spinner-005.png diff --git a/src/qt/res/movies/spinner-006.png b/src/qt/res/animation/spinner-006.png Binary files differindex c1f7d18899..c1f7d18899 100644 --- a/src/qt/res/movies/spinner-006.png +++ b/src/qt/res/animation/spinner-006.png diff --git a/src/qt/res/movies/spinner-007.png b/src/qt/res/animation/spinner-007.png Binary files differindex 1e794b2626..1e794b2626 100644 --- a/src/qt/res/movies/spinner-007.png +++ b/src/qt/res/animation/spinner-007.png diff --git a/src/qt/res/movies/spinner-008.png b/src/qt/res/animation/spinner-008.png Binary files differindex df12ea8719..df12ea8719 100644 --- a/src/qt/res/movies/spinner-008.png +++ b/src/qt/res/animation/spinner-008.png diff --git a/src/qt/res/movies/spinner-009.png b/src/qt/res/animation/spinner-009.png Binary files differindex 18fc3a7d16..18fc3a7d16 100644 --- a/src/qt/res/movies/spinner-009.png +++ b/src/qt/res/animation/spinner-009.png diff --git a/src/qt/res/movies/spinner-010.png b/src/qt/res/animation/spinner-010.png Binary files differindex a79c845fe8..a79c845fe8 100644 --- a/src/qt/res/movies/spinner-010.png +++ b/src/qt/res/animation/spinner-010.png diff --git a/src/qt/res/movies/spinner-011.png b/src/qt/res/animation/spinner-011.png Binary files differindex 57baf66895..57baf66895 100644 --- a/src/qt/res/movies/spinner-011.png +++ b/src/qt/res/animation/spinner-011.png diff --git a/src/qt/res/movies/spinner-012.png b/src/qt/res/animation/spinner-012.png Binary files differindex 9deae7853a..9deae7853a 100644 --- a/src/qt/res/movies/spinner-012.png +++ b/src/qt/res/animation/spinner-012.png diff --git a/src/qt/res/movies/spinner-013.png b/src/qt/res/animation/spinner-013.png Binary files differindex 0659d48dec..0659d48dec 100644 --- a/src/qt/res/movies/spinner-013.png +++ b/src/qt/res/animation/spinner-013.png diff --git a/src/qt/res/movies/spinner-014.png b/src/qt/res/animation/spinner-014.png Binary files differindex bc1ef51bde..bc1ef51bde 100644 --- a/src/qt/res/movies/spinner-014.png +++ b/src/qt/res/animation/spinner-014.png diff --git a/src/qt/res/movies/spinner-015.png b/src/qt/res/animation/spinner-015.png Binary files differindex 24b57b62c2..24b57b62c2 100644 --- a/src/qt/res/movies/spinner-015.png +++ b/src/qt/res/animation/spinner-015.png diff --git a/src/qt/res/movies/spinner-016.png b/src/qt/res/animation/spinner-016.png Binary files differindex d622872651..d622872651 100644 --- a/src/qt/res/movies/spinner-016.png +++ b/src/qt/res/animation/spinner-016.png diff --git a/src/qt/res/movies/spinner-017.png b/src/qt/res/animation/spinner-017.png Binary files differindex f48f688db2..f48f688db2 100644 --- a/src/qt/res/movies/spinner-017.png +++ b/src/qt/res/animation/spinner-017.png diff --git a/src/qt/res/movies/spinner-018.png b/src/qt/res/animation/spinner-018.png Binary files differindex a2c8f38b1d..a2c8f38b1d 100644 --- a/src/qt/res/movies/spinner-018.png +++ b/src/qt/res/animation/spinner-018.png diff --git a/src/qt/res/movies/spinner-019.png b/src/qt/res/animation/spinner-019.png Binary files differindex 9d7cc35d82..9d7cc35d82 100644 --- a/src/qt/res/movies/spinner-019.png +++ b/src/qt/res/animation/spinner-019.png diff --git a/src/qt/res/movies/spinner-020.png b/src/qt/res/animation/spinner-020.png Binary files differindex 1a07acc454..1a07acc454 100644 --- a/src/qt/res/movies/spinner-020.png +++ b/src/qt/res/animation/spinner-020.png diff --git a/src/qt/res/movies/spinner-021.png b/src/qt/res/animation/spinner-021.png Binary files differindex 9cea8f2543..9cea8f2543 100644 --- a/src/qt/res/movies/spinner-021.png +++ b/src/qt/res/animation/spinner-021.png diff --git a/src/qt/res/movies/spinner-022.png b/src/qt/res/animation/spinner-022.png Binary files differindex 60250f6dea..60250f6dea 100644 --- a/src/qt/res/movies/spinner-022.png +++ b/src/qt/res/animation/spinner-022.png diff --git a/src/qt/res/movies/spinner-023.png b/src/qt/res/animation/spinner-023.png Binary files differindex fc290a0cf2..fc290a0cf2 100644 --- a/src/qt/res/movies/spinner-023.png +++ b/src/qt/res/animation/spinner-023.png diff --git a/src/qt/res/movies/spinner-024.png b/src/qt/res/animation/spinner-024.png Binary files differindex c5dcf1eae9..c5dcf1eae9 100644 --- a/src/qt/res/movies/spinner-024.png +++ b/src/qt/res/animation/spinner-024.png diff --git a/src/qt/res/movies/spinner-025.png b/src/qt/res/animation/spinner-025.png Binary files differindex 7f3577a4de..7f3577a4de 100644 --- a/src/qt/res/movies/spinner-025.png +++ b/src/qt/res/animation/spinner-025.png diff --git a/src/qt/res/movies/spinner-026.png b/src/qt/res/animation/spinner-026.png Binary files differindex 1663ddf44c..1663ddf44c 100644 --- a/src/qt/res/movies/spinner-026.png +++ b/src/qt/res/animation/spinner-026.png diff --git a/src/qt/res/movies/spinner-027.png b/src/qt/res/animation/spinner-027.png Binary files differindex d0e6da4503..d0e6da4503 100644 --- a/src/qt/res/movies/spinner-027.png +++ b/src/qt/res/animation/spinner-027.png diff --git a/src/qt/res/movies/spinner-028.png b/src/qt/res/animation/spinner-028.png Binary files differindex 2a7aba50e2..2a7aba50e2 100644 --- a/src/qt/res/movies/spinner-028.png +++ b/src/qt/res/animation/spinner-028.png diff --git a/src/qt/res/movies/spinner-029.png b/src/qt/res/animation/spinner-029.png Binary files differindex c8ca15c1e1..c8ca15c1e1 100644 --- a/src/qt/res/movies/spinner-029.png +++ b/src/qt/res/animation/spinner-029.png diff --git a/src/qt/res/movies/spinner-030.png b/src/qt/res/animation/spinner-030.png Binary files differindex c847c99a93..c847c99a93 100644 --- a/src/qt/res/movies/spinner-030.png +++ b/src/qt/res/animation/spinner-030.png diff --git a/src/qt/res/movies/spinner-031.png b/src/qt/res/animation/spinner-031.png Binary files differindex 403443144e..403443144e 100644 --- a/src/qt/res/movies/spinner-031.png +++ b/src/qt/res/animation/spinner-031.png diff --git a/src/qt/res/movies/spinner-032.png b/src/qt/res/animation/spinner-032.png Binary files differindex f9db080567..f9db080567 100644 --- a/src/qt/res/movies/spinner-032.png +++ b/src/qt/res/animation/spinner-032.png diff --git a/src/qt/res/movies/spinner-033.png b/src/qt/res/animation/spinner-033.png Binary files differindex 43f57719e7..43f57719e7 100644 --- a/src/qt/res/movies/spinner-033.png +++ b/src/qt/res/animation/spinner-033.png diff --git a/src/qt/res/movies/spinner-034.png b/src/qt/res/animation/spinner-034.png Binary files differindex c26656ff17..c26656ff17 100644 --- a/src/qt/res/movies/spinner-034.png +++ b/src/qt/res/animation/spinner-034.png diff --git a/src/qt/res/movies/spinner-035.png b/src/qt/res/animation/spinner-035.png Binary files differindex e471f950a3..e471f950a3 100644 --- a/src/qt/res/movies/spinner-035.png +++ b/src/qt/res/animation/spinner-035.png diff --git a/src/qt/rpcconsole.cpp b/src/qt/rpcconsole.cpp index 0f89d4e6fe..a14fae6460 100644 --- a/src/qt/rpcconsole.cpp +++ b/src/qt/rpcconsole.cpp @@ -466,6 +466,7 @@ RPCConsole::RPCConsole(interfaces::Node& node, const PlatformStyle *_platformSty // Install event filter for up and down arrow ui->lineEdit->installEventFilter(this); + ui->lineEdit->setMaxLength(16 * 1024 * 1024); ui->messagesWidget->installEventFilter(this); connect(ui->clearButton, &QPushButton::clicked, this, &RPCConsole::clear); @@ -555,7 +556,7 @@ bool RPCConsole::eventFilter(QObject* obj, QEvent *event) return QWidget::eventFilter(obj, event); } -void RPCConsole::setClientModel(ClientModel *model) +void RPCConsole::setClientModel(ClientModel *model, int bestblock_height, int64_t bestblock_date, double verification_progress) { clientModel = model; @@ -575,13 +576,13 @@ void RPCConsole::setClientModel(ClientModel *model) setNumConnections(model->getNumConnections()); connect(model, &ClientModel::numConnectionsChanged, this, &RPCConsole::setNumConnections); - interfaces::Node& node = clientModel->node(); - setNumBlocks(node.getNumBlocks(), QDateTime::fromTime_t(node.getLastBlockTime()), node.getVerificationProgress(), false); + setNumBlocks(bestblock_height, QDateTime::fromTime_t(bestblock_date), verification_progress, false); connect(model, &ClientModel::numBlocksChanged, this, &RPCConsole::setNumBlocks); updateNetworkState(); connect(model, &ClientModel::networkActiveChanged, this, &RPCConsole::setNetworkActive); + interfaces::Node& node = clientModel->node(); updateTrafficStats(node.getTotalBytesRecv(), node.getTotalBytesSent()); connect(model, &ClientModel::bytesChanged, this, &RPCConsole::updateTrafficStats); @@ -1118,15 +1119,20 @@ void RPCConsole::updateNodeDetail(const CNodeCombinedStats *stats) ui->peerSubversion->setText(QString::fromStdString(stats->nodeStats.cleanSubVer)); ui->peerDirection->setText(stats->nodeStats.fInbound ? tr("Inbound") : tr("Outbound")); ui->peerHeight->setText(QString::number(stats->nodeStats.nStartingHeight)); - ui->peerWhitelisted->setText(stats->nodeStats.m_legacyWhitelisted ? tr("Yes") : tr("No")); + if (stats->nodeStats.m_permissionFlags == PF_NONE) { + ui->peerPermissions->setText(tr("N/A")); + } else { + QStringList permissions; + for (const auto& permission : NetPermissions::ToStrings(stats->nodeStats.m_permissionFlags)) { + permissions.append(QString::fromStdString(permission)); + } + ui->peerPermissions->setText(permissions.join(" & ")); + } ui->peerMappedAS->setText(stats->nodeStats.m_mapped_as != 0 ? QString::number(stats->nodeStats.m_mapped_as) : tr("N/A")); // This check fails for example if the lock was busy and // nodeStateStats couldn't be fetched. if (stats->fNodeStateStatsAvailable) { - // Ban score is init to 0 - ui->peerBanScore->setText(QString("%1").arg(stats->nodeStateStats.nMisbehavior)); - // Sync height is init to -1 if (stats->nodeStateStats.nSyncHeight > -1) ui->peerSyncHeight->setText(QString("%1").arg(stats->nodeStateStats.nSyncHeight)); @@ -1217,7 +1223,7 @@ void RPCConsole::banSelectedNode(int bantime) // Find possible nodes, ban it and clear the selected node const CNodeCombinedStats *stats = clientModel->getPeerTableModel()->getNodeStats(detailNodeRow); if (stats) { - m_node.ban(stats->nodeStats.addr, BanReasonManuallyAdded, bantime); + m_node.ban(stats->nodeStats.addr, bantime); m_node.disconnectByAddress(stats->nodeStats.addr); } } diff --git a/src/qt/rpcconsole.h b/src/qt/rpcconsole.h index de8e37cca2..280c5bd71a 100644 --- a/src/qt/rpcconsole.h +++ b/src/qt/rpcconsole.h @@ -46,7 +46,7 @@ public: return RPCParseCommandLine(&node, strResult, strCommand, true, pstrFilteredOut, wallet_model); } - void setClientModel(ClientModel *model); + void setClientModel(ClientModel *model = nullptr, int bestblock_height = 0, int64_t bestblock_date = 0, double verification_progress = 0.0); void addWallet(WalletModel * const walletModel); void removeWallet(WalletModel* const walletModel); diff --git a/src/qt/sendcoinsdialog.cpp b/src/qt/sendcoinsdialog.cpp index 9e23fe78d8..97fb88d71c 100644 --- a/src/qt/sendcoinsdialog.cpp +++ b/src/qt/sendcoinsdialog.cpp @@ -21,9 +21,9 @@ #include <chainparams.h> #include <interfaces/node.h> #include <key_io.h> +#include <node/ui_interface.h> #include <policy/fees.h> #include <txmempool.h> -#include <ui_interface.h> #include <wallet/coincontrol.h> #include <wallet/fees.h> #include <wallet/wallet.h> @@ -392,7 +392,7 @@ void SendCoinsDialog::on_sendButton_clicked() CMutableTransaction mtx = CMutableTransaction{*(m_current_transaction->getWtx())}; PartiallySignedTransaction psbtx(mtx); bool complete = false; - const TransactionError err = model->wallet().fillPSBT(SIGHASH_ALL, false /* sign */, true /* bip32derivs */, psbtx, complete); + const TransactionError err = model->wallet().fillPSBT(SIGHASH_ALL, false /* sign */, true /* bip32derivs */, psbtx, complete, nullptr); assert(!complete); assert(err == TransactionError::OK); // Serialize the PSBT diff --git a/src/qt/splashscreen.cpp b/src/qt/splashscreen.cpp index ced6a299d5..6e6b2b8466 100644 --- a/src/qt/splashscreen.cpp +++ b/src/qt/splashscreen.cpp @@ -14,7 +14,6 @@ #include <interfaces/wallet.h> #include <qt/guiutil.h> #include <qt/networkstyle.h> -#include <ui_interface.h> #include <util/system.h> #include <util/translation.h> diff --git a/src/qt/test/addressbooktests.cpp b/src/qt/test/addressbooktests.cpp index 476128520c..035c8196bc 100644 --- a/src/qt/test/addressbooktests.cpp +++ b/src/qt/test/addressbooktests.cpp @@ -18,6 +18,7 @@ #include <key.h> #include <key_io.h> #include <wallet/wallet.h> +#include <walletinitinterface.h> #include <QApplication> #include <QTimer> @@ -59,7 +60,8 @@ void EditAddressAndSubmit( void TestAddAddressesToSendBook(interfaces::Node& node) { TestChain100Setup test; - std::shared_ptr<CWallet> wallet = std::make_shared<CWallet>(node.context()->chain.get(), WalletLocation(), WalletDatabase::CreateMock()); + node.setContext(&test.m_node); + std::shared_ptr<CWallet> wallet = std::make_shared<CWallet>(node.context()->chain.get(), WalletLocation(), CreateMockWalletDatabase()); wallet->SetupLegacyScriptPubKeyMan(); bool firstRun; wallet->LoadWallet(firstRun); diff --git a/src/qt/test/apptests.cpp b/src/qt/test/apptests.cpp index f88d57c716..0b5c341548 100644 --- a/src/qt/test/apptests.cpp +++ b/src/qt/test/apptests.cpp @@ -62,10 +62,12 @@ void AppTests::appTests() } #endif - BasicTestingSetup test{CBaseChainParams::REGTEST}; // Create a temp data directory to backup the gui settings to - ECC_Stop(); // Already started by the common test setup, so stop it to avoid interference - LogInstance().DisconnectTestLogger(); + fs::create_directories([] { + BasicTestingSetup test{CBaseChainParams::REGTEST}; // Create a temp data directory to backup the gui settings to + return GetDataDir() / "blocks"; + }()); + qRegisterMetaType<interfaces::BlockAndHeaderTipInfo>("interfaces::BlockAndHeaderTipInfo"); m_app.parameterSetup(); m_app.createOptionsModel(true /* reset settings */); QScopedPointer<const NetworkStyle> style(NetworkStyle::instantiate(Params().NetworkIDString())); @@ -80,8 +82,9 @@ void AppTests::appTests() m_app.exec(); // Reset global state to avoid interfering with later tests. + LogInstance().DisconnectTestLogger(); AbortShutdown(); - UnloadBlockIndex(); + UnloadBlockIndex(/* mempool */ nullptr); WITH_LOCK(::cs_main, g_chainman.Reset()); } diff --git a/src/qt/test/test_main.cpp b/src/qt/test/test_main.cpp index aefdcd2716..031913bd02 100644 --- a/src/qt/test/test_main.cpp +++ b/src/qt/test/test_main.cpp @@ -40,7 +40,7 @@ Q_IMPORT_PLUGIN(QCocoaIntegrationPlugin); const std::function<void(const std::string&)> G_TEST_LOG_FUN{}; // This is all you need to run all the tests -int main(int argc, char *argv[]) +int main(int argc, char* argv[]) { // Initialize persistent globals with the testing setup state for sanity. // E.g. -datadir in gArgs is set to a temp directory dummy value (instead @@ -52,7 +52,8 @@ int main(int argc, char *argv[]) BasicTestingSetup dummy{CBaseChainParams::REGTEST}; } - std::unique_ptr<interfaces::Node> node = interfaces::MakeNode(); + NodeContext node_context; + std::unique_ptr<interfaces::Node> node = interfaces::MakeNode(&node_context); bool fInvalid = false; @@ -70,6 +71,8 @@ int main(int argc, char *argv[]) BitcoinApplication app(*node); app.setApplicationName("Bitcoin-Qt-test"); + node->setupServerArgs(); // Make gArgs available in the NodeContext + node->context()->args->ClearArgs(); // Clear added args again AppTests app_tests(app); if (QTest::qExec(&app_tests) != 0) { fInvalid = true; diff --git a/src/qt/test/wallettests.cpp b/src/qt/test/wallettests.cpp index 8da0250e57..475fd589af 100644 --- a/src/qt/test/wallettests.cpp +++ b/src/qt/test/wallettests.cpp @@ -138,9 +138,8 @@ void TestGUI(interfaces::Node& node) for (int i = 0; i < 5; ++i) { test.CreateAndProcessBlock({}, GetScriptForRawPubKey(test.coinbaseKey.GetPubKey())); } - node.context()->connman = std::move(test.m_node.connman); - node.context()->mempool = std::move(test.m_node.mempool); - std::shared_ptr<CWallet> wallet = std::make_shared<CWallet>(node.context()->chain.get(), WalletLocation(), WalletDatabase::CreateMock()); + node.setContext(&test.m_node); + std::shared_ptr<CWallet> wallet = std::make_shared<CWallet>(node.context()->chain.get(), WalletLocation(), CreateMockWalletDatabase()); bool firstRun; wallet->LoadWallet(firstRun); { @@ -178,7 +177,7 @@ void TestGUI(interfaces::Node& node) QString balanceText = balanceLabel->text(); int unit = walletModel.getOptionsModel()->getDisplayUnit(); CAmount balance = walletModel.wallet().getBalance(); - QString balanceComparison = BitcoinUnits::formatWithUnit(unit, balance, false, BitcoinUnits::separatorAlways); + QString balanceComparison = BitcoinUnits::formatWithUnit(unit, balance, false, BitcoinUnits::SeparatorStyle::ALWAYS); QCOMPARE(balanceText, balanceComparison); } @@ -204,7 +203,7 @@ void TestGUI(interfaces::Node& node) QString balanceText = balanceLabel->text().trimmed(); int unit = walletModel.getOptionsModel()->getDisplayUnit(); CAmount balance = walletModel.wallet().getBalance(); - QString balanceComparison = BitcoinUnits::formatWithUnit(unit, balance, false, BitcoinUnits::separatorAlways); + QString balanceComparison = BitcoinUnits::formatWithUnit(unit, balance, false, BitcoinUnits::SeparatorStyle::ALWAYS); QCOMPARE(balanceText, balanceComparison); // Check Request Payment button diff --git a/src/qt/transactionrecord.cpp b/src/qt/transactionrecord.cpp index 01dff8069c..632a18de5c 100644 --- a/src/qt/transactionrecord.cpp +++ b/src/qt/transactionrecord.cpp @@ -47,7 +47,6 @@ QList<TransactionRecord> TransactionRecord::decomposeTransaction(const interface if(mine) { TransactionRecord sub(hash, nTime); - CTxDestination address; sub.idx = i; // vout index sub.credit = txout.nValue; sub.involvesWatchAddress = mine & ISMINE_WATCH_ONLY; @@ -235,6 +234,7 @@ void TransactionRecord::updateStatus(const interfaces::WalletTxStatus& wtx, cons bool TransactionRecord::statusUpdateNeeded(const uint256& block_hash) const { + assert(!block_hash.IsNull()); return status.m_cur_block_hash != block_hash || status.needsUpdate; } diff --git a/src/qt/transactiontablemodel.cpp b/src/qt/transactiontablemodel.cpp index 04eb1ae706..c560dc58e7 100644 --- a/src/qt/transactiontablemodel.cpp +++ b/src/qt/transactiontablemodel.cpp @@ -178,21 +178,16 @@ public: TransactionRecord* index(interfaces::Wallet& wallet, const uint256& cur_block_hash, const int idx) { - if(idx >= 0 && idx < cachedWallet.size()) - { + if (idx >= 0 && idx < cachedWallet.size()) { TransactionRecord *rec = &cachedWallet[idx]; - // Get required locks upfront. This avoids the GUI from getting - // stuck if the core is holding the locks for a longer time - for - // example, during a wallet rescan. - // // If a status update is needed (blocks came in since last check), - // update the status of this transaction from the wallet. Otherwise, - // simply re-use the cached status. + // try to update the status of this transaction from the wallet. + // Otherwise, simply re-use the cached status. interfaces::WalletTxStatus wtx; int numBlocks; int64_t block_time; - if (rec->statusUpdateNeeded(cur_block_hash) && wallet.tryGetTxStatus(rec->hash, wtx, numBlocks, block_time)) { + if (!cur_block_hash.IsNull() && rec->statusUpdateNeeded(cur_block_hash) && wallet.tryGetTxStatus(rec->hash, wtx, numBlocks, block_time)) { rec->updateStatus(wtx, cur_block_hash, numBlocks, block_time); } return rec; @@ -524,7 +519,7 @@ QVariant TransactionTableModel::data(const QModelIndex &index, int role) const case ToAddress: return formatTxToAddress(rec, false); case Amount: - return formatTxAmount(rec, true, BitcoinUnits::separatorAlways); + return formatTxAmount(rec, true, BitcoinUnits::SeparatorStyle::ALWAYS); } break; case Qt::EditRole: @@ -614,14 +609,14 @@ QVariant TransactionTableModel::data(const QModelIndex &index, int role) const details.append(QString::fromStdString(rec->address)); details.append(" "); } - details.append(formatTxAmount(rec, false, BitcoinUnits::separatorNever)); + details.append(formatTxAmount(rec, false, BitcoinUnits::SeparatorStyle::NEVER)); return details; } case ConfirmedRole: return rec->status.status == TransactionStatus::Status::Confirming || rec->status.status == TransactionStatus::Status::Confirmed; case FormattedAmountRole: // Used for copy/export, so don't include separators - return formatTxAmount(rec, false, BitcoinUnits::separatorNever); + return formatTxAmount(rec, false, BitcoinUnits::SeparatorStyle::NEVER); case StatusRole: return rec->status.status; } @@ -664,7 +659,7 @@ QVariant TransactionTableModel::headerData(int section, Qt::Orientation orientat QModelIndex TransactionTableModel::index(int row, int column, const QModelIndex &parent) const { Q_UNUSED(parent); - TransactionRecord* data = priv->index(walletModel->wallet(), walletModel->clientModel().getBestBlockHash(), row); + TransactionRecord* data = priv->index(walletModel->wallet(), walletModel->getLastBlockProcessed(), row); if(data) { return createIndex(row, column, data); diff --git a/src/qt/transactiontablemodel.h b/src/qt/transactiontablemodel.h index f06f0ea15f..4b699d4d7d 100644 --- a/src/qt/transactiontablemodel.h +++ b/src/qt/transactiontablemodel.h @@ -101,7 +101,7 @@ private: QString formatTxDate(const TransactionRecord *wtx) const; QString formatTxType(const TransactionRecord *wtx) const; QString formatTxToAddress(const TransactionRecord *wtx, bool tooltip) const; - QString formatTxAmount(const TransactionRecord *wtx, bool showUnconfirmed=true, BitcoinUnits::SeparatorStyle separators=BitcoinUnits::separatorStandard) const; + QString formatTxAmount(const TransactionRecord *wtx, bool showUnconfirmed=true, BitcoinUnits::SeparatorStyle separators=BitcoinUnits::SeparatorStyle::STANDARD) const; QString formatTooltip(const TransactionRecord *rec) const; QVariant txStatusDecoration(const TransactionRecord *wtx) const; QVariant txWatchonlyDecoration(const TransactionRecord *wtx) const; diff --git a/src/qt/transactionview.cpp b/src/qt/transactionview.cpp index 3df81807f0..54ecfc38ec 100644 --- a/src/qt/transactionview.cpp +++ b/src/qt/transactionview.cpp @@ -17,7 +17,7 @@ #include <qt/transactiontablemodel.h> #include <qt/walletmodel.h> -#include <ui_interface.h> +#include <node/ui_interface.h> #include <QApplication> #include <QComboBox> diff --git a/src/qt/walletcontroller.cpp b/src/qt/walletcontroller.cpp index 20f2ef5b5f..3aed98e0e8 100644 --- a/src/qt/walletcontroller.cpp +++ b/src/qt/walletcontroller.cpp @@ -92,6 +92,23 @@ void WalletController::closeWallet(WalletModel* wallet_model, QWidget* parent) removeAndDeleteWallet(wallet_model); } +void WalletController::closeAllWallets(QWidget* parent) +{ + QMessageBox::StandardButton button = QMessageBox::question(parent, tr("Close all wallets"), + tr("Are you sure you wish to close all wallets?"), + QMessageBox::Yes|QMessageBox::Cancel, + QMessageBox::Yes); + if (button != QMessageBox::Yes) return; + + QMutexLocker locker(&m_mutex); + for (WalletModel* wallet_model : m_wallets) { + wallet_model->wallet().remove(); + Q_EMIT walletRemoved(wallet_model); + delete wallet_model; + } + m_wallets.clear(); +} + WalletModel* WalletController::getOrCreateWallet(std::unique_ptr<interfaces::Wallet> wallet) { QMutexLocker locker(&m_mutex); @@ -245,7 +262,7 @@ void CreateWalletActivity::finish() { destroyProgressDialog(); - if (!m_error_message.original.empty()) { + if (!m_error_message.empty()) { QMessageBox::critical(m_parent_widget, tr("Create wallet failed"), QString::fromStdString(m_error_message.translated)); } else if (!m_warning_message.empty()) { QMessageBox::warning(m_parent_widget, tr("Create wallet warning"), QString::fromStdString(Join(m_warning_message, Untranslated("\n")).translated)); @@ -286,7 +303,7 @@ void OpenWalletActivity::finish() { destroyProgressDialog(); - if (!m_error_message.original.empty()) { + if (!m_error_message.empty()) { QMessageBox::critical(m_parent_widget, tr("Open wallet failed"), QString::fromStdString(m_error_message.translated)); } else if (!m_warning_message.empty()) { QMessageBox::warning(m_parent_widget, tr("Open wallet warning"), QString::fromStdString(Join(m_warning_message, Untranslated("\n")).translated)); diff --git a/src/qt/walletcontroller.h b/src/qt/walletcontroller.h index 24dd83adf7..f7e366878d 100644 --- a/src/qt/walletcontroller.h +++ b/src/qt/walletcontroller.h @@ -62,6 +62,7 @@ public: std::map<std::string, bool> listWalletDir() const; void closeWallet(WalletModel* wallet_model, QWidget* parent = nullptr); + void closeAllWallets(QWidget* parent = nullptr); Q_SIGNALS: void walletAdded(WalletModel* wallet_model); diff --git a/src/qt/walletframe.cpp b/src/qt/walletframe.cpp index 5e68ee4f93..ec56f2755f 100644 --- a/src/qt/walletframe.cpp +++ b/src/qt/walletframe.cpp @@ -165,11 +165,11 @@ void WalletFrame::gotoVerifyMessageTab(QString addr) walletView->gotoVerifyMessageTab(addr); } -void WalletFrame::gotoLoadPSBT() +void WalletFrame::gotoLoadPSBT(bool from_clipboard) { WalletView *walletView = currentWalletView(); if (walletView) { - walletView->gotoLoadPSBT(); + walletView->gotoLoadPSBT(from_clipboard); } } diff --git a/src/qt/walletframe.h b/src/qt/walletframe.h index d90ade5005..2b5f263468 100644 --- a/src/qt/walletframe.h +++ b/src/qt/walletframe.h @@ -79,7 +79,7 @@ public Q_SLOTS: void gotoVerifyMessageTab(QString addr = ""); /** Load Partially Signed Bitcoin Transaction */ - void gotoLoadPSBT(); + void gotoLoadPSBT(bool from_clipboard = false); /** Encrypt the wallet */ void encryptWallet(bool status); diff --git a/src/qt/walletmodel.cpp b/src/qt/walletmodel.cpp index 671b5e1ce6..e374dd191c 100644 --- a/src/qt/walletmodel.cpp +++ b/src/qt/walletmodel.cpp @@ -21,8 +21,8 @@ #include <interfaces/handler.h> #include <interfaces/node.h> #include <key_io.h> +#include <node/ui_interface.h> #include <psbt.h> -#include <ui_interface.h> #include <util/system.h> // for GetBoolArg #include <util/translation.h> #include <wallet/coincontrol.h> @@ -87,7 +87,7 @@ void WalletModel::pollBalanceChanged() { // Avoid recomputing wallet balances unless a TransactionChanged or // BlockTip notification was received. - if (!fForceCheckBalanceChanged && m_cached_last_update_tip == m_client_model->getBestBlockHash()) return; + if (!fForceCheckBalanceChanged && m_cached_last_update_tip == getLastBlockProcessed()) return; // Try to get balances and return early if locks can't be acquired. This // avoids the GUI from getting stuck on periodical polls if the core is @@ -536,7 +536,7 @@ bool WalletModel::bumpFee(uint256 hash, uint256& new_hash) if (create_psbt) { PartiallySignedTransaction psbtx(mtx); bool complete = false; - const TransactionError err = wallet().fillPSBT(SIGHASH_ALL, false /* sign */, true /* bip32derivs */, psbtx, complete); + const TransactionError err = wallet().fillPSBT(SIGHASH_ALL, false /* sign */, true /* bip32derivs */, psbtx, complete, nullptr); if (err != TransactionError::OK || complete) { QMessageBox::critical(nullptr, tr("Fee bump error"), tr("Can't draft transaction.")); return false; @@ -588,3 +588,8 @@ void WalletModel::refresh(bool pk_hash_only) { addressTableModel = new AddressTableModel(this, pk_hash_only); } + +uint256 WalletModel::getLastBlockProcessed() const +{ + return m_client_model ? m_client_model->getBestBlockHash() : uint256{}; +} diff --git a/src/qt/walletmodel.h b/src/qt/walletmodel.h index 38e8a14556..fd52db2da3 100644 --- a/src/qt/walletmodel.h +++ b/src/qt/walletmodel.h @@ -155,6 +155,9 @@ public: AddressTableModel* getAddressTableModel() const { return addressTableModel; } void refresh(bool pk_hash_only = false); + + uint256 getLastBlockProcessed() const; + private: std::unique_ptr<interfaces::Wallet> m_wallet; std::unique_ptr<interfaces::Handler> m_handler_unload; diff --git a/src/qt/walletview.cpp b/src/qt/walletview.cpp index 66fbf978be..2fc883a5f5 100644 --- a/src/qt/walletview.cpp +++ b/src/qt/walletview.cpp @@ -4,13 +4,11 @@ #include <qt/walletview.h> -#include <node/psbt.h> -#include <node/transaction.h> -#include <policy/policy.h> #include <qt/addressbookpage.h> #include <qt/askpassphrasedialog.h> #include <qt/clientmodel.h> #include <qt/guiutil.h> +#include <qt/psbtoperationsdialog.h> #include <qt/optionsmodel.h> #include <qt/overviewpage.h> #include <qt/platformstyle.h> @@ -22,11 +20,14 @@ #include <qt/walletmodel.h> #include <interfaces/node.h> -#include <ui_interface.h> +#include <node/ui_interface.h> +#include <psbt.h> #include <util/strencodings.h> #include <QAction> #include <QActionGroup> +#include <QApplication> +#include <QClipboard> #include <QFileDialog> #include <QHBoxLayout> #include <QProgressDialog> @@ -204,78 +205,42 @@ void WalletView::gotoVerifyMessageTab(QString addr) signVerifyMessageDialog->setAddress_VM(addr); } -void WalletView::gotoLoadPSBT() +void WalletView::gotoLoadPSBT(bool from_clipboard) { - QString filename = GUIUtil::getOpenFileName(this, - tr("Load Transaction Data"), QString(), - tr("Partially Signed Transaction (*.psbt)"), nullptr); - if (filename.isEmpty()) return; - if (GetFileSize(filename.toLocal8Bit().data(), MAX_FILE_SIZE_PSBT) == MAX_FILE_SIZE_PSBT) { - Q_EMIT message(tr("Error"), tr("PSBT file must be smaller than 100 MiB"), CClientUIInterface::MSG_ERROR); - return; + std::string data; + + if (from_clipboard) { + std::string raw = QApplication::clipboard()->text().toStdString(); + bool invalid; + data = DecodeBase64(raw, &invalid); + if (invalid) { + Q_EMIT message(tr("Error"), tr("Unable to decode PSBT from clipboard (invalid base64)"), CClientUIInterface::MSG_ERROR); + return; + } + } else { + QString filename = GUIUtil::getOpenFileName(this, + tr("Load Transaction Data"), QString(), + tr("Partially Signed Transaction (*.psbt)"), nullptr); + if (filename.isEmpty()) return; + if (GetFileSize(filename.toLocal8Bit().data(), MAX_FILE_SIZE_PSBT) == MAX_FILE_SIZE_PSBT) { + Q_EMIT message(tr("Error"), tr("PSBT file must be smaller than 100 MiB"), CClientUIInterface::MSG_ERROR); + return; + } + std::ifstream in(filename.toLocal8Bit().data(), std::ios::binary); + data = std::string(std::istreambuf_iterator<char>{in}, {}); } - std::ifstream in(filename.toLocal8Bit().data(), std::ios::binary); - std::string data(std::istreambuf_iterator<char>{in}, {}); std::string error; PartiallySignedTransaction psbtx; if (!DecodeRawPSBT(psbtx, data, error)) { - Q_EMIT message(tr("Error"), tr("Unable to decode PSBT file") + "\n" + QString::fromStdString(error), CClientUIInterface::MSG_ERROR); + Q_EMIT message(tr("Error"), tr("Unable to decode PSBT") + "\n" + QString::fromStdString(error), CClientUIInterface::MSG_ERROR); return; } - CMutableTransaction mtx; - bool complete = false; - PSBTAnalysis analysis = AnalyzePSBT(psbtx); - QMessageBox msgBox; - msgBox.setText("PSBT"); - switch (analysis.next) { - case PSBTRole::CREATOR: - case PSBTRole::UPDATER: - msgBox.setInformativeText("PSBT is incomplete. Copy to clipboard for manual inspection?"); - break; - case PSBTRole::SIGNER: - msgBox.setInformativeText("Transaction needs more signatures. Copy to clipboard?"); - break; - case PSBTRole::FINALIZER: - case PSBTRole::EXTRACTOR: - complete = FinalizeAndExtractPSBT(psbtx, mtx); - if (complete) { - msgBox.setInformativeText(tr("Would you like to send this transaction?")); - } else { - // The analyzer missed something, e.g. if there are final_scriptSig/final_scriptWitness - // but with invalid signatures. - msgBox.setInformativeText(tr("There was an unexpected problem processing the PSBT. Copy to clipboard for manual inspection?")); - } - } - - msgBox.setStandardButtons(QMessageBox::Yes | QMessageBox::Cancel); - switch (msgBox.exec()) { - case QMessageBox::Yes: { - if (complete) { - std::string err_string; - CTransactionRef tx = MakeTransactionRef(mtx); - - TransactionError result = BroadcastTransaction(*clientModel->node().context(), tx, err_string, DEFAULT_MAX_RAW_TX_FEE_RATE.GetFeePerK(), /* relay */ true, /* wait_callback */ false); - if (result == TransactionError::OK) { - Q_EMIT message(tr("Success"), tr("Broadcasted transaction sucessfully."), CClientUIInterface::MSG_INFORMATION | CClientUIInterface::MODAL); - } else { - Q_EMIT message(tr("Error"), QString::fromStdString(err_string), CClientUIInterface::MSG_ERROR); - } - } else { - // Serialize the PSBT - CDataStream ssTx(SER_NETWORK, PROTOCOL_VERSION); - ssTx << psbtx; - GUIUtil::setClipboard(EncodeBase64(ssTx.str()).c_str()); - Q_EMIT message(tr("PSBT copied"), "Copied to clipboard", CClientUIInterface::MSG_INFORMATION); - return; - } - } - case QMessageBox::Cancel: - break; - default: - assert(false); - } + PSBTOperationsDialog* dlg = new PSBTOperationsDialog(this, walletModel, clientModel); + dlg->openWithPSBT(psbtx); + dlg->setAttribute(Qt::WA_DeleteOnClose); + dlg->exec(); } bool WalletView::handlePaymentRequest(const SendCoinsRecipient& recipient) diff --git a/src/qt/walletview.h b/src/qt/walletview.h index fd09456baa..f186554758 100644 --- a/src/qt/walletview.h +++ b/src/qt/walletview.h @@ -84,7 +84,7 @@ public Q_SLOTS: /** Show Sign/Verify Message dialog and switch to verify message tab */ void gotoVerifyMessageTab(QString addr = ""); /** Load Partially Signed Bitcoin Transaction */ - void gotoLoadPSBT(); + void gotoLoadPSBT(bool from_clipboard = false); /** Show incoming transaction notification for new transactions. diff --git a/src/random.cpp b/src/random.cpp index 9c9a35709a..af9504e0ce 100644 --- a/src/random.cpp +++ b/src/random.cpp @@ -315,12 +315,16 @@ void GetOSRand(unsigned char *ent32) if (getentropy(ent32, NUM_OS_RANDOM_BYTES) != 0) { RandFailure(); } + // Silence a compiler warning about unused function. + (void)GetDevURandom; #elif defined(HAVE_GETENTROPY_RAND) && defined(MAC_OSX) /* getentropy() is available on macOS 10.12 and later. */ if (getentropy(ent32, NUM_OS_RANDOM_BYTES) != 0) { RandFailure(); } + // Silence a compiler warning about unused function. + (void)GetDevURandom; #elif defined(HAVE_SYSCTL_ARND) /* FreeBSD, NetBSD and similar. It is possible for the call to return less * bytes than requested, so need to read in a loop. @@ -334,6 +338,8 @@ void GetOSRand(unsigned char *ent32) } have += len; } while (have < NUM_OS_RANDOM_BYTES); + // Silence a compiler warning about unused function. + (void)GetDevURandom; #else /* Fall back to /dev/urandom if there is no specific method implemented to * get system entropy for this OS. diff --git a/src/rest.cpp b/src/rest.cpp index cde8b472d3..7130625d5c 100644 --- a/src/rest.cpp +++ b/src/rest.cpp @@ -68,13 +68,32 @@ static bool RESTERR(HTTPRequest* req, enum HTTPStatusCode status, std::string me } /** - * Get the node context mempool. + * Get the node context. * - * Set the HTTP error and return nullptr if node context - * mempool is not found. + * @param[in] req The HTTP request, whose status code will be set if node + * context is not found. + * @returns Pointer to the node context or nullptr if not found. + */ +static NodeContext* GetNodeContext(const util::Ref& context, HTTPRequest* req) +{ + NodeContext* node = context.Has<NodeContext>() ? &context.Get<NodeContext>() : nullptr; + if (!node) { + RESTERR(req, HTTP_INTERNAL_SERVER_ERROR, + strprintf("%s:%d (%s)\n" + "Internal bug detected: Node context not found!\n" + "You may report this issue here: %s\n", + __FILE__, __LINE__, __func__, PACKAGE_BUGREPORT)); + return nullptr; + } + return node; +} + +/** + * Get the node context mempool. * - * @param[in] req the HTTP request - * return pointer to the mempool or nullptr if no mempool found + * @param[in] req The HTTP request, whose status code will be set if node + * context mempool is not found. + * @returns Pointer to the mempool or nullptr if no mempool found. */ static CTxMemPool* GetMemPool(const util::Ref& context, HTTPRequest* req) { @@ -188,7 +207,7 @@ static bool rest_headers(const util::Ref& context, ssHeader << pindex->GetBlockHeader(); } - std::string strHex = HexStr(ssHeader.begin(), ssHeader.end()) + "\n"; + std::string strHex = HexStr(ssHeader) + "\n"; req->WriteHeader("Content-Type", "text/plain"); req->WriteReply(HTTP_OK, strHex); return true; @@ -253,7 +272,7 @@ static bool rest_block(HTTPRequest* req, case RetFormat::HEX: { CDataStream ssBlock(SER_NETWORK, PROTOCOL_VERSION | RPCSerializationFlags()); ssBlock << block; - std::string strHex = HexStr(ssBlock.begin(), ssBlock.end()) + "\n"; + std::string strHex = HexStr(ssBlock) + "\n"; req->WriteHeader("Content-Type", "text/plain"); req->WriteReply(HTTP_OK, strHex); return true; @@ -371,10 +390,13 @@ static bool rest_tx(const util::Ref& context, HTTPRequest* req, const std::strin g_txindex->BlockUntilSyncedToCurrentChain(); } - CTransactionRef tx; + const NodeContext* const node = GetNodeContext(context, req); + if (!node) return false; uint256 hashBlock = uint256(); - if (!GetTransaction(hash, tx, Params().GetConsensus(), hashBlock)) + const CTransactionRef tx = GetTransaction(/* block_index */ nullptr, node->mempool, hash, Params().GetConsensus(), hashBlock); + if (!tx) { return RESTERR(req, HTTP_NOT_FOUND, hashStr + " not found"); + } switch (rf) { case RetFormat::BINARY: { @@ -391,7 +413,7 @@ static bool rest_tx(const util::Ref& context, HTTPRequest* req, const std::strin CDataStream ssTx(SER_NETWORK, PROTOCOL_VERSION | RPCSerializationFlags()); ssTx << tx; - std::string strHex = HexStr(ssTx.begin(), ssTx.end()) + "\n"; + std::string strHex = HexStr(ssTx) + "\n"; req->WriteHeader("Content-Type", "text/plain"); req->WriteReply(HTTP_OK, strHex); return true; @@ -556,7 +578,7 @@ static bool rest_getutxos(const util::Ref& context, HTTPRequest* req, const std: case RetFormat::HEX: { CDataStream ssGetUTXOResponse(SER_NETWORK, PROTOCOL_VERSION); ssGetUTXOResponse << ::ChainActive().Height() << ::ChainActive().Tip()->GetBlockHash() << bitmap << outs; - std::string strHex = HexStr(ssGetUTXOResponse.begin(), ssGetUTXOResponse.end()) + "\n"; + std::string strHex = HexStr(ssGetUTXOResponse) + "\n"; req->WriteHeader("Content-Type", "text/plain"); req->WriteReply(HTTP_OK, strHex); diff --git a/src/rpc/blockchain.cpp b/src/rpc/blockchain.cpp index b0936cef5a..f27373b57c 100644 --- a/src/rpc/blockchain.cpp +++ b/src/rpc/blockchain.cpp @@ -32,6 +32,7 @@ #include <util/ref.h> #include <util/strencodings.h> #include <util/system.h> +#include <util/translation.h> #include <validation.h> #include <validationinterface.h> #include <warnings.h> @@ -74,7 +75,10 @@ CTxMemPool& EnsureMemPool(const util::Ref& context) ChainstateManager& EnsureChainman(const util::Ref& context) { NodeContext& node = EnsureNodeContext(context); - return EnsureChainman(node); + if (!node.chainman) { + throw JSONRPCError(RPC_INTERNAL_ERROR, "Node chainman not found"); + } + return *node.chainman; } /* Calculate the difficulty for a given block index. @@ -521,9 +525,9 @@ static UniValue getrawmempool(const JSONRPCRequest& request) {RPCResult::Type::STR_HEX, "", "The transaction id"}, }}, RPCResult{"for verbose = true", - RPCResult::Type::OBJ, "", "", + RPCResult::Type::OBJ_DYN, "", "", { - {RPCResult::Type::OBJ_DYN, "transactionid", "", MempoolEntryDescription()}, + {RPCResult::Type::OBJ, "transactionid", "", MempoolEntryDescription()}, }}, }, RPCExamples{ @@ -552,7 +556,7 @@ static UniValue getmempoolancestors(const JSONRPCRequest& request) RPCResult::Type::ARR, "", "", {{RPCResult::Type::STR_HEX, "", "The transaction id of an in-mempool ancestor transaction"}}}, RPCResult{"for verbose = true", - RPCResult::Type::OBJ_DYN, "transactionid", "", MempoolEntryDescription()}, + RPCResult::Type::OBJ, "transactionid", "", MempoolEntryDescription()}, }, RPCExamples{ HelpExampleCli("getmempoolancestors", "\"mytxid\"") @@ -612,9 +616,9 @@ static UniValue getmempooldescendants(const JSONRPCRequest& request) RPCResult::Type::ARR, "", "", {{RPCResult::Type::STR_HEX, "", "The transaction id of an in-mempool descendant transaction"}}}, RPCResult{"for verbose = true", - RPCResult::Type::OBJ, "", "", + RPCResult::Type::OBJ_DYN, "", "", { - {RPCResult::Type::OBJ_DYN, "transactionid", "", MempoolEntryDescription()}, + {RPCResult::Type::OBJ, "transactionid", "", MempoolEntryDescription()}, }}, }, RPCExamples{ @@ -670,7 +674,7 @@ static UniValue getmempoolentry(const JSONRPCRequest& request) {"txid", RPCArg::Type::STR_HEX, RPCArg::Optional::NO, "The transaction id (must be in mempool)"}, }, RPCResult{ - RPCResult::Type::OBJ_DYN, "", "", MempoolEntryDescription()}, + RPCResult::Type::OBJ, "", "", MempoolEntryDescription()}, RPCExamples{ HelpExampleCli("getmempoolentry", "\"mytxid\"") + HelpExampleRpc("getmempoolentry", "\"mytxid\"") @@ -778,7 +782,7 @@ static UniValue getblockheader(const JSONRPCRequest& request) { CDataStream ssBlock(SER_NETWORK, PROTOCOL_VERSION); ssBlock << pblockindex->GetBlockHeader(); - std::string strHex = HexStr(ssBlock.begin(), ssBlock.end()); + std::string strHex = HexStr(ssBlock); return strHex; } @@ -794,10 +798,8 @@ static CBlock GetBlockChecked(const CBlockIndex* pblockindex) if (!ReadBlockFromDisk(block, pblockindex, Params().GetConsensus())) { // Block not found on disk. This could be because we have the block - // header in our index but don't have the block (for example if a - // non-whitelisted node sends us an unrequested long chain of valid - // blocks, we add the headers to our index, but don't accept the - // block). + // header in our index but not yet have the block or did not accept the + // block. throw JSONRPCError(RPC_MISC_ERROR, "Block not found on disk"); } @@ -904,7 +906,7 @@ static UniValue getblock(const JSONRPCRequest& request) { CDataStream ssBlock(SER_NETWORK, PROTOCOL_VERSION | RPCSerializationFlags()); ssBlock << block; - std::string strHex = HexStr(ssBlock.begin(), ssBlock.end()); + std::string strHex = HexStr(ssBlock); return strHex; } @@ -971,7 +973,9 @@ static UniValue gettxoutsetinfo(const JSONRPCRequest& request) RPCHelpMan{"gettxoutsetinfo", "\nReturns statistics about the unspent transaction output set.\n" "Note this call may take some time.\n", - {}, + { + {"hash_type", RPCArg::Type::STR, /* default */ "hash_serialized_2", "Which UTXO set hash should be calculated. Options: 'hash_serialized_2' (the legacy algorithm), 'none'."}, + }, RPCResult{ RPCResult::Type::OBJ, "", "", { @@ -980,7 +984,7 @@ static UniValue gettxoutsetinfo(const JSONRPCRequest& request) {RPCResult::Type::NUM, "transactions", "The number of transactions with unspent outputs"}, {RPCResult::Type::NUM, "txouts", "The number of unspent transaction outputs"}, {RPCResult::Type::NUM, "bogosize", "A meaningless metric for UTXO set size"}, - {RPCResult::Type::STR_HEX, "hash_serialized_2", "The serialized hash"}, + {RPCResult::Type::STR_HEX, "hash_serialized_2", "The serialized hash (only present if 'hash_serialized_2' hash_type is chosen)"}, {RPCResult::Type::NUM, "disk_size", "The estimated size of the chainstate on disk"}, {RPCResult::Type::STR_AMOUNT, "total_amount", "The total amount"}, }}, @@ -995,14 +999,19 @@ static UniValue gettxoutsetinfo(const JSONRPCRequest& request) CCoinsStats stats; ::ChainstateActive().ForceFlushStateToDisk(); + const CoinStatsHashType hash_type = ParseHashType(request.params[0], CoinStatsHashType::HASH_SERIALIZED); + CCoinsView* coins_view = WITH_LOCK(cs_main, return &ChainstateActive().CoinsDB()); - if (GetUTXOStats(coins_view, stats, RpcInterruptionPoint)) { + NodeContext& node = EnsureNodeContext(request.context); + if (GetUTXOStats(coins_view, stats, hash_type, node.rpc_interruption_point)) { ret.pushKV("height", (int64_t)stats.nHeight); ret.pushKV("bestblock", stats.hashBlock.GetHex()); ret.pushKV("transactions", (int64_t)stats.nTransactions); ret.pushKV("txouts", (int64_t)stats.nTransactionOutputs); ret.pushKV("bogosize", (int64_t)stats.nBogoSize); - ret.pushKV("hash_serialized_2", stats.hashSerialized.GetHex()); + if (hash_type == CoinStatsHashType::HASH_SERIALIZED) { + ret.pushKV("hash_serialized_2", stats.hashSerialized.GetHex()); + } ret.pushKV("disk_size", stats.nDiskSize); ret.pushKV("total_amount", ValueFromAmount(stats.nTotalAmount)); } else { @@ -1095,7 +1104,8 @@ static UniValue verifychain(const JSONRPCRequest& request) RPCHelpMan{"verifychain", "\nVerifies blockchain database.\n", { - {"checklevel", RPCArg::Type::NUM, /* default */ strprintf("%d, range=0-4", DEFAULT_CHECKLEVEL), "How thorough the block verification is."}, + {"checklevel", RPCArg::Type::NUM, /* default */ strprintf("%d, range=0-4", DEFAULT_CHECKLEVEL), + strprintf("How thorough the block verification is:\n - %s", Join(CHECKLEVEL_DOC, "\n- "))}, {"nblocks", RPCArg::Type::NUM, /* default */ strprintf("%d, 0=all", DEFAULT_CHECKBLOCKS), "The number of blocks to check."}, }, RPCResult{ @@ -1277,7 +1287,7 @@ UniValue getblockchaininfo(const JSONRPCRequest& request) BIP9SoftForkDescPushBack(softforks, "testdummy", consensusParams, Consensus::DEPLOYMENT_TESTDUMMY); obj.pushKV("softforks", softforks); - obj.pushKV("warnings", GetWarnings(false)); + obj.pushKV("warnings", GetWarnings(false).original); return obj; } @@ -1323,50 +1333,48 @@ static UniValue getchaintips(const JSONRPCRequest& request) }, }.Check(request); + ChainstateManager& chainman = EnsureChainman(request.context); LOCK(cs_main); /* - * Idea: the set of chain tips is ::ChainActive().tip, plus orphan blocks which do not have another orphan building off of them. + * Idea: The set of chain tips is the active chain tip, plus orphan blocks which do not have another orphan building off of them. * Algorithm: * - Make one pass through BlockIndex(), picking out the orphan blocks, and also storing a set of the orphan block's pprev pointers. * - Iterate through the orphan blocks. If the block isn't pointed to by another orphan, it is a chain tip. - * - add ::ChainActive().Tip() + * - Add the active chain tip */ std::set<const CBlockIndex*, CompareBlocksByHeight> setTips; std::set<const CBlockIndex*> setOrphans; std::set<const CBlockIndex*> setPrevs; - for (const std::pair<const uint256, CBlockIndex*>& item : ::BlockIndex()) - { - if (!::ChainActive().Contains(item.second)) { + for (const std::pair<const uint256, CBlockIndex*>& item : chainman.BlockIndex()) { + if (!chainman.ActiveChain().Contains(item.second)) { setOrphans.insert(item.second); setPrevs.insert(item.second->pprev); } } - for (std::set<const CBlockIndex*>::iterator it = setOrphans.begin(); it != setOrphans.end(); ++it) - { + for (std::set<const CBlockIndex*>::iterator it = setOrphans.begin(); it != setOrphans.end(); ++it) { if (setPrevs.erase(*it) == 0) { setTips.insert(*it); } } // Always report the currently active tip. - setTips.insert(::ChainActive().Tip()); + setTips.insert(chainman.ActiveChain().Tip()); /* Construct the output array. */ UniValue res(UniValue::VARR); - for (const CBlockIndex* block : setTips) - { + for (const CBlockIndex* block : setTips) { UniValue obj(UniValue::VOBJ); obj.pushKV("height", block->nHeight); obj.pushKV("hash", block->phashBlock->GetHex()); - const int branchLen = block->nHeight - ::ChainActive().FindFork(block)->nHeight; + const int branchLen = block->nHeight - chainman.ActiveChain().FindFork(block)->nHeight; obj.pushKV("branchlen", branchLen); std::string status; - if (::ChainActive().Contains(block)) { + if (chainman.ActiveChain().Contains(block)) { // This block is part of the currently active chain. status = "active"; } else if (block->nStatus & BLOCK_FAILED_MASK) { @@ -1965,8 +1973,10 @@ static UniValue savemempool(const JSONRPCRequest& request) return NullUniValue; } +namespace { //! Search for a given set of pubkey scripts -bool FindScriptPubKey(std::atomic<int>& scan_progress, const std::atomic<bool>& should_abort, int64_t& count, CCoinsViewCursor* cursor, const std::set<CScript>& needles, std::map<COutPoint, Coin>& out_results) { +bool FindScriptPubKey(std::atomic<int>& scan_progress, const std::atomic<bool>& should_abort, int64_t& count, CCoinsViewCursor* cursor, const std::set<CScript>& needles, std::map<COutPoint, Coin>& out_results, std::function<void()>& interruption_point) +{ scan_progress = 0; count = 0; while (cursor->Valid()) { @@ -1974,7 +1984,7 @@ bool FindScriptPubKey(std::atomic<int>& scan_progress, const std::atomic<bool>& Coin coin; if (!cursor->GetKey(key) || !cursor->GetValue(coin)) return false; if (++count % 8192 == 0) { - RpcInterruptionPoint(); + interruption_point(); if (should_abort) { // allow to abort the scan via the abort reference return false; @@ -1993,6 +2003,7 @@ bool FindScriptPubKey(std::atomic<int>& scan_progress, const std::atomic<bool>& scan_progress = 100; return true; } +} // namespace /** RAII object to prevent concurrency issue when scanning the txout set */ static std::atomic<int> g_scan_progress; @@ -2141,7 +2152,8 @@ UniValue scantxoutset(const JSONRPCRequest& request) tip = ::ChainActive().Tip(); CHECK_NONFATAL(tip); } - bool res = FindScriptPubKey(g_scan_progress, g_should_abort_scan, count, pcursor.get(), needles, coins); + NodeContext& node = EnsureNodeContext(request.context); + bool res = FindScriptPubKey(g_scan_progress, g_should_abort_scan, count, pcursor.get(), needles, coins, node.rpc_interruption_point); result.pushKV("success", res); result.pushKV("txouts", count); result.pushKV("height", tip->nHeight); @@ -2157,7 +2169,7 @@ UniValue scantxoutset(const JSONRPCRequest& request) UniValue unspent(UniValue::VOBJ); unspent.pushKV("txid", outpoint.hash.GetHex()); unspent.pushKV("vout", (int32_t)outpoint.n); - unspent.pushKV("scriptPubKey", HexStr(txo.scriptPubKey.begin(), txo.scriptPubKey.end())); + unspent.pushKV("scriptPubKey", HexStr(txo.scriptPubKey)); unspent.pushKV("desc", descriptors[txo.scriptPubKey]); unspent.pushKV("amount", ValueFromAmount(txo.nValue)); unspent.pushKV("height", (int32_t)coin.nHeight); @@ -2296,6 +2308,7 @@ UniValue dumptxoutset(const JSONRPCRequest& request) std::unique_ptr<CCoinsViewCursor> pcursor; CCoinsStats stats; CBlockIndex* tip; + NodeContext& node = EnsureNodeContext(request.context); { // We need to lock cs_main to ensure that the coinsdb isn't written to @@ -2314,7 +2327,7 @@ UniValue dumptxoutset(const JSONRPCRequest& request) ::ChainstateActive().ForceFlushStateToDisk(); - if (!GetUTXOStats(&::ChainstateActive().CoinsDB(), stats, RpcInterruptionPoint)) { + if (!GetUTXOStats(&::ChainstateActive().CoinsDB(), stats, CoinStatsHashType::NONE, node.rpc_interruption_point)) { throw JSONRPCError(RPC_INTERNAL_ERROR, "Unable to read UTXO set"); } @@ -2332,7 +2345,7 @@ UniValue dumptxoutset(const JSONRPCRequest& request) unsigned int iter{0}; while (pcursor->Valid()) { - if (iter % 5000 == 0) RpcInterruptionPoint(); + if (iter % 5000 == 0) node.rpc_interruption_point(); ++iter; if (pcursor->GetKey(key) && pcursor->GetValue(coin)) { afile << key; @@ -2375,7 +2388,7 @@ static const CRPCCommand commands[] = { "blockchain", "getmempoolinfo", &getmempoolinfo, {} }, { "blockchain", "getrawmempool", &getrawmempool, {"verbose"} }, { "blockchain", "gettxout", &gettxout, {"txid","n","include_mempool"} }, - { "blockchain", "gettxoutsetinfo", &gettxoutsetinfo, {} }, + { "blockchain", "gettxoutsetinfo", &gettxoutsetinfo, {"hash_type"} }, { "blockchain", "pruneblockchain", &pruneblockchain, {"height"} }, { "blockchain", "savemempool", &savemempool, {} }, { "blockchain", "verifychain", &verifychain, {"checklevel","nblocks"} }, diff --git a/src/rpc/client.cpp b/src/rpc/client.cpp index 3045a74d7a..d61e02aee2 100644 --- a/src/rpc/client.cpp +++ b/src/rpc/client.cpp @@ -151,6 +151,7 @@ static const CRPCConvertParam vRPCConvertParams[] = { "getmempoolancestors", 1, "verbose" }, { "getmempooldescendants", 1, "verbose" }, { "bumpfee", 1, "options" }, + { "psbtbumpfee", 1, "options" }, { "logging", 0, "include" }, { "logging", 1, "exclude" }, { "disconnectnode", 1, "nodeid" }, @@ -173,6 +174,7 @@ static const CRPCConvertParam vRPCConvertParams[] = { "createwallet", 4, "avoid_reuse"}, { "createwallet", 5, "descriptors"}, { "getnodeaddresses", 0, "count"}, + { "addpeeraddress", 1, "port"}, { "stop", 0, "wait" }, }; // clang-format on @@ -217,7 +219,7 @@ UniValue ParseNonRFCJSONValue(const std::string& strVal) UniValue jVal; if (!jVal.read(std::string("[")+strVal+std::string("]")) || !jVal.isArray() || jVal.size()!=1) - throw std::runtime_error(std::string("Error parsing JSON:")+strVal); + throw std::runtime_error(std::string("Error parsing JSON: ") + strVal); return jVal[0]; } diff --git a/src/rpc/mining.cpp b/src/rpc/mining.cpp index 3612f14bbf..fee6a893eb 100644 --- a/src/rpc/mining.cpp +++ b/src/rpc/mining.cpp @@ -17,6 +17,7 @@ #include <policy/fees.h> #include <pow.h> #include <rpc/blockchain.h> +#include <rpc/mining.h> #include <rpc/server.h> #include <rpc/util.h> #include <script/descriptor.h> @@ -29,6 +30,7 @@ #include <util/strencodings.h> #include <util/string.h> #include <util/system.h> +#include <util/translation.h> #include <validation.h> #include <validationinterface.h> #include <versionbitsinfo.h> @@ -179,7 +181,7 @@ static bool getScriptFromDescriptor(const std::string& descriptor, CScript& scri throw JSONRPCError(RPC_INVALID_ADDRESS_OR_KEY, strprintf("Cannot derive script without private keys")); } - // Combo desriptors can have 2 or 4 scripts, so we can't just check scripts.size() == 1 + // Combo descriptors can have 2 or 4 scripts, so we can't just check scripts.size() == 1 CHECK_NONFATAL(scripts.size() > 0 && scripts.size() <= 4); if (scripts.size() == 1) { @@ -206,7 +208,7 @@ static UniValue generatetodescriptor(const JSONRPCRequest& request) { {"num_blocks", RPCArg::Type::NUM, RPCArg::Optional::NO, "How many blocks are generated immediately."}, {"descriptor", RPCArg::Type::STR, RPCArg::Optional::NO, "The descriptor to send the newly generated bitcoin to."}, - {"maxtries", RPCArg::Type::NUM, /* default */ "1000000", "How many iterations to try."}, + {"maxtries", RPCArg::Type::NUM, /* default */ ToString(DEFAULT_MAX_TRIES), "How many iterations to try."}, }, RPCResult{ RPCResult::Type::ARR, "", "hashes of blocks generated", @@ -220,7 +222,7 @@ static UniValue generatetodescriptor(const JSONRPCRequest& request) .Check(request); const int num_blocks{request.params[0].get_int()}; - const int64_t max_tries{request.params[2].isNull() ? 1000000 : request.params[2].get_int()}; + const uint64_t max_tries{request.params[2].isNull() ? DEFAULT_MAX_TRIES : request.params[2].get_int()}; CScript coinbase_script; std::string error; @@ -241,7 +243,7 @@ static UniValue generatetoaddress(const JSONRPCRequest& request) { {"nblocks", RPCArg::Type::NUM, RPCArg::Optional::NO, "How many blocks are generated immediately."}, {"address", RPCArg::Type::STR, RPCArg::Optional::NO, "The address to send the newly generated bitcoin to."}, - {"maxtries", RPCArg::Type::NUM, /* default */ "1000000", "How many iterations to try."}, + {"maxtries", RPCArg::Type::NUM, /* default */ ToString(DEFAULT_MAX_TRIES), "How many iterations to try."}, }, RPCResult{ RPCResult::Type::ARR, "", "hashes of blocks generated", @@ -251,16 +253,13 @@ static UniValue generatetoaddress(const JSONRPCRequest& request) RPCExamples{ "\nGenerate 11 blocks to myaddress\n" + HelpExampleCli("generatetoaddress", "11 \"myaddress\"") - + "If you are running the bitcoin core wallet, you can get a new address to send the newly generated bitcoin to with:\n" + + "If you are using the " PACKAGE_NAME " wallet, you can get a new address to send the newly generated bitcoin to with:\n" + HelpExampleCli("getnewaddress", "") }, }.Check(request); - int nGenerate = request.params[0].get_int(); - uint64_t nMaxTries = 1000000; - if (!request.params[2].isNull()) { - nMaxTries = request.params[2].get_int(); - } + const int num_blocks{request.params[0].get_int()}; + const uint64_t max_tries{request.params[2].isNull() ? DEFAULT_MAX_TRIES : request.params[2].get_int()}; CTxDestination destination = DecodeDestination(request.params[1].get_str()); if (!IsValidDestination(destination)) { @@ -272,7 +271,7 @@ static UniValue generatetoaddress(const JSONRPCRequest& request) CScript coinbase_script = GetScriptForDestination(destination); - return generateBlocks(chainman, mempool, coinbase_script, nGenerate, nMaxTries); + return generateBlocks(chainman, mempool, coinbase_script, num_blocks, max_tries); } static UniValue generateblock(const JSONRPCRequest& request) @@ -370,7 +369,7 @@ static UniValue generateblock(const JSONRPCRequest& request) } uint256 block_hash; - uint64_t max_tries{1000000}; + uint64_t max_tries{DEFAULT_MAX_TRIES}; unsigned int extra_nonce{0}; if (!GenerateBlock(EnsureChainman(request.context), block, max_tries, extra_nonce, block_hash) || block_hash.IsNull()) { @@ -416,7 +415,7 @@ static UniValue getmininginfo(const JSONRPCRequest& request) obj.pushKV("networkhashps", getnetworkhashps(request)); obj.pushKV("pooledtx", (uint64_t)mempool.size()); obj.pushKV("chain", Params().NetworkIDString()); - obj.pushKV("warnings", GetWarnings(false)); + obj.pushKV("warnings", GetWarnings(false).original); return obj; } @@ -874,7 +873,7 @@ static UniValue getblocktemplate(const JSONRPCRequest& request) result.pushKV("height", (int64_t)(pindexPrev->nHeight+1)); if (!pblocktemplate->vchCoinbaseCommitment.empty()) { - result.pushKV("default_witness_commitment", HexStr(pblocktemplate->vchCoinbaseCommitment.begin(), pblocktemplate->vchCoinbaseCommitment.end())); + result.pushKV("default_witness_commitment", HexStr(pblocktemplate->vchCoinbaseCommitment)); } return result; diff --git a/src/rpc/mining.h b/src/rpc/mining.h new file mode 100644 index 0000000000..acc74e1dcc --- /dev/null +++ b/src/rpc/mining.h @@ -0,0 +1,11 @@ +// Copyright (c) 2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#ifndef BITCOIN_RPC_MINING_H +#define BITCOIN_RPC_MINING_H + +/** Default max iterations to try in RPC generatetodescriptor, generatetoaddress, and generateblock. */ +static const uint64_t DEFAULT_MAX_TRIES{1000000}; + +#endif // BITCOIN_RPC_MINING_H diff --git a/src/rpc/misc.cpp b/src/rpc/misc.cpp index ce98a7c937..53d38f4e11 100644 --- a/src/rpc/misc.cpp +++ b/src/rpc/misc.cpp @@ -63,7 +63,7 @@ static UniValue validateaddress(const JSONRPCRequest& request) ret.pushKV("address", currentAddress); CScript scriptPubKey = GetScriptForDestination(dest); - ret.pushKV("scriptPubKey", HexStr(scriptPubKey.begin(), scriptPubKey.end())); + ret.pushKV("scriptPubKey", HexStr(scriptPubKey)); UniValue detail = DescribeAddress(dest); ret.pushKVs(detail); @@ -131,7 +131,7 @@ static UniValue createmultisig(const JSONRPCRequest& request) UniValue result(UniValue::VOBJ); result.pushKV("address", EncodeDestination(dest)); - result.pushKV("redeemScript", HexStr(inner.begin(), inner.end())); + result.pushKV("redeemScript", HexStr(inner)); result.pushKV("descriptor", descriptor->ToString()); return result; diff --git a/src/rpc/net.cpp b/src/rpc/net.cpp index e29aa03695..77d24a6e79 100644 --- a/src/rpc/net.cpp +++ b/src/rpc/net.cpp @@ -22,6 +22,7 @@ #include <util/strencodings.h> #include <util/string.h> #include <util/system.h> +#include <util/translation.h> #include <validation.h> #include <version.h> #include <warnings.h> @@ -111,7 +112,7 @@ static UniValue getpeerinfo(const JSONRPCRequest& request) {RPCResult::Type::BOOL, "inbound", "Inbound (true) or Outbound (false)"}, {RPCResult::Type::BOOL, "addnode", "Whether connection was due to addnode/-connect or if it was an automatic/inbound connection"}, {RPCResult::Type::NUM, "startingheight", "The starting height (block) of the peer"}, - {RPCResult::Type::NUM, "banscore", "The ban score"}, + {RPCResult::Type::NUM, "banscore", "The ban score (DEPRECATED, returned only if config option -deprecatedrpc=banscore is passed)"}, {RPCResult::Type::NUM, "synced_headers", "The last header we have in common with this peer"}, {RPCResult::Type::NUM, "synced_blocks", "The last block we have in common with this peer"}, {RPCResult::Type::ARR, "inflight", "", @@ -190,7 +191,10 @@ static UniValue getpeerinfo(const JSONRPCRequest& request) obj.pushKV("addnode", stats.m_manual_connection); obj.pushKV("startingheight", stats.nStartingHeight); if (fStateStats) { - obj.pushKV("banscore", statestats.nMisbehavior); + if (IsDeprecatedRPCEnabled("banscore")) { + // banscore is deprecated in v0.21 for removal in v0.22 + obj.pushKV("banscore", statestats.nMisbehavior); + } obj.pushKV("synced_headers", statestats.nSyncHeight); obj.pushKV("synced_blocks", statestats.nCommonHeight); UniValue heights(UniValue::VARR); @@ -260,7 +264,7 @@ static UniValue addnode(const JSONRPCRequest& request) if (strCommand == "onetry") { CAddress addr; - node.connman->OpenNetworkConnection(addr, false, nullptr, strNode.c_str(), false, false, true); + node.connman->OpenNetworkConnection(addr, false, nullptr, strNode.c_str(), ConnectionType::MANUAL); return NullUniValue; } @@ -272,7 +276,7 @@ static UniValue addnode(const JSONRPCRequest& request) else if(strCommand == "remove") { if(!node.connman->RemoveAddedNode(strNode)) - throw JSONRPCError(RPC_CLIENT_NODE_NOT_ADDED, "Error: Node has not been added."); + throw JSONRPCError(RPC_CLIENT_NODE_NOT_ADDED, "Error: Node could not be removed. It has not been added previously."); } return NullUniValue; @@ -548,7 +552,7 @@ static UniValue getnetworkinfo(const JSONRPCRequest& request) } } obj.pushKV("localaddresses", localAddresses); - obj.pushKV("warnings", GetWarnings(false)); + obj.pushKV("warnings", GetWarnings(false).original); return obj; } @@ -613,12 +617,12 @@ static UniValue setban(const JSONRPCRequest& request) absolute = true; if (isSubnet) { - node.banman->Ban(subNet, BanReasonManuallyAdded, banTime, absolute); + node.banman->Ban(subNet, banTime, absolute); if (node.connman) { node.connman->DisconnectNode(subNet); } } else { - node.banman->Ban(netAddr, BanReasonManuallyAdded, banTime, absolute); + node.banman->Ban(netAddr, banTime, absolute); if (node.connman) { node.connman->DisconnectNode(netAddr); } @@ -627,7 +631,7 @@ static UniValue setban(const JSONRPCRequest& request) else if(strCommand == "remove") { if (!( isSubnet ? node.banman->Unban(subNet) : node.banman->Unban(netAddr) )) { - throw JSONRPCError(RPC_CLIENT_INVALID_IP_OR_SUBNET, "Error: Unban failed. Requested address/subnet was not previously banned."); + throw JSONRPCError(RPC_CLIENT_INVALID_IP_OR_SUBNET, "Error: Unban failed. Requested address/subnet was not previously manually banned."); } } return NullUniValue; @@ -636,7 +640,7 @@ static UniValue setban(const JSONRPCRequest& request) static UniValue listbanned(const JSONRPCRequest& request) { RPCHelpMan{"listbanned", - "\nList all banned IPs/Subnets.\n", + "\nList all manually banned IPs/Subnets.\n", {}, RPCResult{RPCResult::Type::ARR, "", "", { @@ -645,7 +649,6 @@ static UniValue listbanned(const JSONRPCRequest& request) {RPCResult::Type::STR, "address", ""}, {RPCResult::Type::NUM_TIME, "banned_until", ""}, {RPCResult::Type::NUM_TIME, "ban_created", ""}, - {RPCResult::Type::STR, "ban_reason", ""}, }}, }}, RPCExamples{ @@ -670,7 +673,6 @@ static UniValue listbanned(const JSONRPCRequest& request) rec.pushKV("address", entry.first.ToString()); rec.pushKV("banned_until", banEntry.nBanUntil); rec.pushKV("ban_created", banEntry.nCreateTime); - rec.pushKV("ban_reason", banEntry.banReasonToString()); bannedAddresses.push_back(rec); } @@ -725,7 +727,7 @@ static UniValue getnodeaddresses(const JSONRPCRequest& request) RPCHelpMan{"getnodeaddresses", "\nReturn known addresses which can potentially be used to find new nodes in the network\n", { - {"count", RPCArg::Type::NUM, /* default */ "1", "How many addresses to return. Limited to the smaller of " + ToString(ADDRMAN_GETADDR_MAX) + " or " + ToString(ADDRMAN_GETADDR_MAX_PCT) + "% of all known addresses."}, + {"count", RPCArg::Type::NUM, /* default */ "1", "The maximum number of addresses to return. Specify 0 to return all known addresses."}, }, RPCResult{ RPCResult::Type::ARR, "", "", @@ -752,18 +754,16 @@ static UniValue getnodeaddresses(const JSONRPCRequest& request) int count = 1; if (!request.params[0].isNull()) { count = request.params[0].get_int(); - if (count <= 0) { + if (count < 0) { throw JSONRPCError(RPC_INVALID_PARAMETER, "Address count out of range"); } } // returns a shuffled list of CAddress - std::vector<CAddress> vAddr = node.connman->GetAddresses(); + std::vector<CAddress> vAddr = node.connman->GetAddresses(count, /* max_pct */ 0); UniValue ret(UniValue::VARR); - int address_return_count = std::min<int>(count, vAddr.size()); - for (int i = 0; i < address_return_count; ++i) { + for (const CAddress& addr : vAddr) { UniValue obj(UniValue::VOBJ); - const CAddress& addr = vAddr[i]; obj.pushKV("time", (int)addr.nTime); obj.pushKV("services", (uint64_t)addr.nServices); obj.pushKV("address", addr.ToStringIP()); @@ -773,6 +773,54 @@ static UniValue getnodeaddresses(const JSONRPCRequest& request) return ret; } +static UniValue addpeeraddress(const JSONRPCRequest& request) +{ + RPCHelpMan{"addpeeraddress", + "\nAdd the address of a potential peer to the address manager. This RPC is for testing only.\n", + { + {"address", RPCArg::Type::STR, RPCArg::Optional::NO, "The IP address of the peer"}, + {"port", RPCArg::Type::NUM, RPCArg::Optional::NO, "The port of the peer"}, + }, + RPCResult{ + RPCResult::Type::OBJ, "", "", + { + {RPCResult::Type::BOOL, "success", "whether the peer address was successfully added to the address manager"}, + }, + }, + RPCExamples{ + HelpExampleCli("addpeeraddress", "\"1.2.3.4\" 8333") + + HelpExampleRpc("addpeeraddress", "\"1.2.3.4\", 8333") + }, + }.Check(request); + + NodeContext& node = EnsureNodeContext(request.context); + if (!node.connman) { + throw JSONRPCError(RPC_CLIENT_P2P_DISABLED, "Error: Peer-to-peer functionality missing or disabled"); + } + + UniValue obj(UniValue::VOBJ); + + std::string addr_string = request.params[0].get_str(); + uint16_t port = request.params[1].get_int(); + + CNetAddr net_addr; + if (!LookupHost(addr_string, net_addr, false)) { + obj.pushKV("success", false); + return obj; + } + CAddress address = CAddress({net_addr, port}, ServiceFlags(NODE_NETWORK|NODE_WITNESS)); + address.nTime = GetAdjustedTime(); + // The source address is set equal to the address. This is equivalent to the peer + // announcing itself. + if (!node.connman->AddNewAddresses({address}, address)) { + obj.pushKV("success", false); + return obj; + } + + obj.pushKV("success", true); + return obj; +} + void RegisterNetRPCCommands(CRPCTable &t) { // clang-format off @@ -792,6 +840,7 @@ static const CRPCCommand commands[] = { "network", "clearbanned", &clearbanned, {} }, { "network", "setnetworkactive", &setnetworkactive, {"state"} }, { "network", "getnodeaddresses", &getnodeaddresses, {"count"} }, + { "hidden", "addpeeraddress", &addpeeraddress, {"address", "port"} }, }; // clang-format on diff --git a/src/rpc/rawtransaction.cpp b/src/rpc/rawtransaction.cpp index e14217c307..9b6ef15785 100644 --- a/src/rpc/rawtransaction.cpp +++ b/src/rpc/rawtransaction.cpp @@ -28,6 +28,7 @@ #include <script/signingprovider.h> #include <script/standard.h> #include <uint256.h> +#include <util/bip32.h> #include <util/moneystr.h> #include <util/strencodings.h> #include <util/string.h> @@ -156,6 +157,8 @@ static UniValue getrawtransaction(const JSONRPCRequest& request) }, }.Check(request); + const NodeContext& node = EnsureNodeContext(request.context); + bool in_active_chain = true; uint256 hash = ParseHashV(request.params[0], "parameter 1"); CBlockIndex* blockindex = nullptr; @@ -187,9 +190,9 @@ static UniValue getrawtransaction(const JSONRPCRequest& request) f_txindex_ready = g_txindex->BlockUntilSyncedToCurrentChain(); } - CTransactionRef tx; uint256 hash_block; - if (!GetTransaction(hash, tx, Params().GetConsensus(), hash_block, blockindex)) { + const CTransactionRef tx = GetTransaction(blockindex, node.mempool, hash, Params().GetConsensus(), hash_block); + if (!tx) { std::string errmsg; if (blockindex) { if (!(blockindex->nStatus & BLOCK_HAVE_DATA)) { @@ -244,10 +247,11 @@ static UniValue gettxoutproof(const JSONRPCRequest& request) for (unsigned int idx = 0; idx < txids.size(); idx++) { const UniValue& txid = txids[idx]; uint256 hash(ParseHashV(txid, "txid")); - if (setTxids.count(hash)) - throw JSONRPCError(RPC_INVALID_PARAMETER, std::string("Invalid parameter, duplicated txid: ")+txid.get_str()); - setTxids.insert(hash); - oneTxid = hash; + if (setTxids.count(hash)) { + throw JSONRPCError(RPC_INVALID_PARAMETER, std::string("Invalid parameter, duplicated txid: ") + txid.get_str()); + } + setTxids.insert(hash); + oneTxid = hash; } CBlockIndex* pblockindex = nullptr; @@ -280,11 +284,11 @@ static UniValue gettxoutproof(const JSONRPCRequest& request) LOCK(cs_main); - if (pblockindex == nullptr) - { - CTransactionRef tx; - if (!GetTransaction(oneTxid, tx, Params().GetConsensus(), hashBlock) || hashBlock.IsNull()) + if (pblockindex == nullptr) { + const CTransactionRef tx = GetTransaction(/* block_index */ nullptr, /* mempool */ nullptr, oneTxid, Params().GetConsensus(), hashBlock); + if (!tx || hashBlock.IsNull()) { throw JSONRPCError(RPC_INVALID_ADDRESS_OR_KEY, "Transaction not yet in block"); + } pblockindex = LookupBlockIndex(hashBlock); if (!pblockindex) { throw JSONRPCError(RPC_INTERNAL_ERROR, "Transaction index corrupt"); @@ -292,20 +296,24 @@ static UniValue gettxoutproof(const JSONRPCRequest& request) } CBlock block; - if(!ReadBlockFromDisk(block, pblockindex, Params().GetConsensus())) + if (!ReadBlockFromDisk(block, pblockindex, Params().GetConsensus())) { throw JSONRPCError(RPC_INTERNAL_ERROR, "Can't read block from disk"); + } unsigned int ntxFound = 0; - for (const auto& tx : block.vtx) - if (setTxids.count(tx->GetHash())) + for (const auto& tx : block.vtx) { + if (setTxids.count(tx->GetHash())) { ntxFound++; - if (ntxFound != setTxids.size()) + } + } + if (ntxFound != setTxids.size()) { throw JSONRPCError(RPC_INVALID_ADDRESS_OR_KEY, "Not all transactions found in specified or retrieved block"); + } CDataStream ssMB(SER_NETWORK, PROTOCOL_VERSION | SERIALIZE_TRANSACTION_NO_WITNESS); CMerkleBlock mb(block, setTxids); ssMB << mb; - std::string strHex = HexStr(ssMB.begin(), ssMB.end()); + std::string strHex = HexStr(ssMB); return strHex; } @@ -511,12 +519,12 @@ static UniValue decoderawtransaction(const JSONRPCRequest& request) static std::string GetAllOutputTypes() { - std::string ret; - for (int i = TX_NONSTANDARD; i <= TX_WITNESS_UNKNOWN; ++i) { - if (i != TX_NONSTANDARD) ret += ", "; - ret += GetTxnOutputType(static_cast<txnouttype>(i)); + std::vector<std::string> ret; + using U = std::underlying_type<TxoutType>::type; + for (U i = (U)TxoutType::NONSTANDARD; i <= (U)TxoutType::WITNESS_UNKNOWN; ++i) { + ret.emplace_back(GetTxnOutputType(static_cast<TxoutType>(i))); } - return ret; + return Join(ret, ", "); } static UniValue decodescript(const JSONRPCRequest& request) @@ -580,10 +588,10 @@ static UniValue decodescript(const JSONRPCRequest& request) // is a witness program, don't return addresses for a segwit programs. if (type.get_str() == "pubkey" || type.get_str() == "pubkeyhash" || type.get_str() == "multisig" || type.get_str() == "nonstandard") { std::vector<std::vector<unsigned char>> solutions_data; - txnouttype which_type = Solver(script, solutions_data); + TxoutType which_type = Solver(script, solutions_data); // Uncompressed pubkeys cannot be used with segwit checksigs. // If the script contains an uncompressed pubkey, skip encoding of a segwit program. - if ((which_type == TX_PUBKEY) || (which_type == TX_MULTISIG)) { + if ((which_type == TxoutType::PUBKEY) || (which_type == TxoutType::MULTISIG)) { for (const auto& solution : solutions_data) { if ((solution.size() != 1) && !CPubKey(solution).IsCompressed()) { return r; @@ -592,10 +600,10 @@ static UniValue decodescript(const JSONRPCRequest& request) } UniValue sr(UniValue::VOBJ); CScript segwitScr; - if (which_type == TX_PUBKEY) { - segwitScr = GetScriptForDestination(WitnessV0KeyHash(Hash160(solutions_data[0].begin(), solutions_data[0].end()))); - } else if (which_type == TX_PUBKEYHASH) { - segwitScr = GetScriptForDestination(WitnessV0KeyHash(solutions_data[0])); + if (which_type == TxoutType::PUBKEY) { + segwitScr = GetScriptForDestination(WitnessV0KeyHash(Hash160(solutions_data[0]))); + } else if (which_type == TxoutType::PUBKEYHASH) { + segwitScr = GetScriptForDestination(WitnessV0KeyHash(uint160{solutions_data[0]})); } else { // Scripts that are not fit for P2WPKH are encoded as P2WSH. // Newer segwit program versions should be considered when then become available. @@ -938,25 +946,6 @@ static UniValue testmempoolaccept(const JSONRPCRequest& request) return result; } -static std::string WriteHDKeypath(std::vector<uint32_t>& keypath) -{ - std::string keypath_str = "m"; - for (uint32_t num : keypath) { - keypath_str += "/"; - bool hardened = false; - if (num & 0x80000000) { - hardened = true; - num &= ~0x80000000; - } - - keypath_str += ToString(num); - if (hardened) { - keypath_str += "'"; - } - } - return keypath_str; -} - UniValue decodepsbt(const JSONRPCRequest& request) { RPCHelpMan{"decodepsbt", @@ -1104,30 +1093,34 @@ UniValue decodepsbt(const JSONRPCRequest& request) const PSBTInput& input = psbtx.inputs[i]; UniValue in(UniValue::VOBJ); // UTXOs + bool have_a_utxo = false; + CTxOut txout; if (!input.witness_utxo.IsNull()) { - const CTxOut& txout = input.witness_utxo; - - UniValue out(UniValue::VOBJ); - - out.pushKV("amount", ValueFromAmount(txout.nValue)); - if (MoneyRange(txout.nValue) && MoneyRange(total_in + txout.nValue)) { - total_in += txout.nValue; - } else { - // Hack to just not show fee later - have_all_utxos = false; - } + txout = input.witness_utxo; UniValue o(UniValue::VOBJ); ScriptToUniv(txout.scriptPubKey, o, true); + + UniValue out(UniValue::VOBJ); + out.pushKV("amount", ValueFromAmount(txout.nValue)); out.pushKV("scriptPubKey", o); + in.pushKV("witness_utxo", out); - } else if (input.non_witness_utxo) { + + have_a_utxo = true; + } + if (input.non_witness_utxo) { + txout = input.non_witness_utxo->vout[psbtx.tx->vin[i].prevout.n]; + UniValue non_wit(UniValue::VOBJ); TxToUniv(*input.non_witness_utxo, uint256(), non_wit, false); in.pushKV("non_witness_utxo", non_wit); - CAmount utxo_val = input.non_witness_utxo->vout[psbtx.tx->vin[i].prevout.n].nValue; - if (MoneyRange(utxo_val) && MoneyRange(total_in + utxo_val)) { - total_in += utxo_val; + + have_a_utxo = true; + } + if (have_a_utxo) { + if (MoneyRange(txout.nValue) && MoneyRange(total_in + txout.nValue)) { + total_in += txout.nValue; } else { // Hack to just not show fee later have_all_utxos = false; @@ -1186,7 +1179,7 @@ UniValue decodepsbt(const JSONRPCRequest& request) if (!input.final_script_witness.IsNull()) { UniValue txinwitness(UniValue::VARR); for (const auto& item : input.final_script_witness.stack) { - txinwitness.push_back(HexStr(item.begin(), item.end())); + txinwitness.push_back(HexStr(item)); } in.pushKV("final_scriptwitness", txinwitness); } @@ -1355,7 +1348,7 @@ UniValue finalizepsbt(const JSONRPCRequest& request) if (complete && extract) { ssTx << mtx; - result_str = HexStr(ssTx.str()); + result_str = HexStr(ssTx); result.pushKV("hex", result_str); } else { ssTx << psbtx; @@ -1628,7 +1621,7 @@ UniValue joinpsbts(const JSONRPCRequest& request) throw JSONRPCError(RPC_INVALID_PARAMETER, "At least two PSBTs are required to join PSBTs."); } - int32_t best_version = 1; + uint32_t best_version = 1; uint32_t best_locktime = 0xffffffff; for (unsigned int i = 0; i < txs.size(); ++i) { PartiallySignedTransaction psbtx; @@ -1638,8 +1631,8 @@ UniValue joinpsbts(const JSONRPCRequest& request) } psbtxs.push_back(psbtx); // Choose the highest version number - if (psbtx.tx->nVersion > best_version) { - best_version = psbtx.tx->nVersion; + if (static_cast<uint32_t>(psbtx.tx->nVersion) > best_version) { + best_version = static_cast<uint32_t>(psbtx.tx->nVersion); } // Choose the lowest lock time if (psbtx.tx->nLockTime < best_locktime) { @@ -1650,7 +1643,7 @@ UniValue joinpsbts(const JSONRPCRequest& request) // Create a blank psbt where everything will be added PartiallySignedTransaction merged_psbt; merged_psbt.tx = CMutableTransaction(); - merged_psbt.tx->nVersion = best_version; + merged_psbt.tx->nVersion = static_cast<int32_t>(best_version); merged_psbt.tx->nLockTime = best_locktime; // Merge diff --git a/src/rpc/rawtransaction_util.cpp b/src/rpc/rawtransaction_util.cpp index bd82773bf2..1031716b4a 100644 --- a/src/rpc/rawtransaction_util.cpp +++ b/src/rpc/rawtransaction_util.cpp @@ -138,10 +138,10 @@ static void TxInErrorToJSON(const CTxIn& txin, UniValue& vErrorsRet, const std:: entry.pushKV("vout", (uint64_t)txin.prevout.n); UniValue witness(UniValue::VARR); for (unsigned int i = 0; i < txin.scriptWitness.stack.size(); i++) { - witness.push_back(HexStr(txin.scriptWitness.stack[i].begin(), txin.scriptWitness.stack[i].end())); + witness.push_back(HexStr(txin.scriptWitness.stack[i])); } entry.pushKV("witness", witness); - entry.pushKV("scriptSig", HexStr(txin.scriptSig.begin(), txin.scriptSig.end())); + entry.pushKV("scriptSig", HexStr(txin.scriptSig)); entry.pushKV("sequence", (uint64_t)txin.nSequence); entry.pushKV("error", strMessage); vErrorsRet.push_back(entry); diff --git a/src/rpc/request.cpp b/src/rpc/request.cpp index 7fef45f50e..d9ad70fa37 100644 --- a/src/rpc/request.cpp +++ b/src/rpc/request.cpp @@ -78,7 +78,7 @@ bool GenerateAuthCookie(std::string *cookie_out) const size_t COOKIE_SIZE = 32; unsigned char rand_pwd[COOKIE_SIZE]; GetRandBytes(rand_pwd, COOKIE_SIZE); - std::string cookie = COOKIEAUTH_USER + ":" + HexStr(rand_pwd, rand_pwd+COOKIE_SIZE); + std::string cookie = COOKIEAUTH_USER + ":" + HexStr(rand_pwd); /** the umask determines what permissions are used to create this file - * these are set to 077 in init.cpp unless overridden with -sysperms. diff --git a/src/rpc/request.h b/src/rpc/request.h index 02ec5393a7..4761e9e371 100644 --- a/src/rpc/request.h +++ b/src/rpc/request.h @@ -41,6 +41,16 @@ public: const util::Ref& context; JSONRPCRequest(const util::Ref& context) : id(NullUniValue), params(NullUniValue), fHelp(false), context(context) {} + + //! Initializes request information from another request object and the + //! given context. The implementation should be updated if any members are + //! added or removed above. + JSONRPCRequest(const JSONRPCRequest& other, const util::Ref& context) + : id(other.id), strMethod(other.strMethod), params(other.params), fHelp(other.fHelp), URI(other.URI), + authUser(other.authUser), peerAddr(other.peerAddr), context(context) + { + } + void parse(const UniValue& valRequest); }; diff --git a/src/rpc/server.cpp b/src/rpc/server.cpp index 2a0079ac39..e5f6b1b9f1 100644 --- a/src/rpc/server.cpp +++ b/src/rpc/server.cpp @@ -20,12 +20,10 @@ #include <mutex> #include <unordered_map> -static RecursiveMutex cs_rpcWarmup; +static Mutex g_rpc_warmup_mutex; static std::atomic<bool> g_rpc_running{false}; -static std::once_flag g_rpc_interrupt_flag; -static std::once_flag g_rpc_stop_flag; -static bool fRPCInWarmup GUARDED_BY(cs_rpcWarmup) = true; -static std::string rpcWarmupStatus GUARDED_BY(cs_rpcWarmup) = "RPC server started"; +static bool fRPCInWarmup GUARDED_BY(g_rpc_warmup_mutex) = true; +static std::string rpcWarmupStatus GUARDED_BY(g_rpc_warmup_mutex) = "RPC server started"; /* Timer-creating functions */ static RPCTimerInterface* timerInterface = nullptr; /* Map of name to timer. */ @@ -132,11 +130,9 @@ std::string CRPCTable::help(const std::string& strCommand, const JSONRPCRequest& return strRet; } -UniValue help(const JSONRPCRequest& jsonRequest) +static RPCHelpMan help() { - if (jsonRequest.fHelp || jsonRequest.params.size() > 1) - throw std::runtime_error( - RPCHelpMan{"help", + return RPCHelpMan{"help", "\nList all commands, or get help for a specified command.\n", { {"command", RPCArg::Type::STR, /* default */ "all commands", "The command to get help on"}, @@ -145,32 +141,32 @@ UniValue help(const JSONRPCRequest& jsonRequest) RPCResult::Type::STR, "", "The help text" }, RPCExamples{""}, - }.ToString() - ); - + [&](const RPCHelpMan& self, const JSONRPCRequest& jsonRequest) -> UniValue +{ std::string strCommand; if (jsonRequest.params.size() > 0) strCommand = jsonRequest.params[0].get_str(); return tableRPC.help(strCommand, jsonRequest); +}, + }; } - -UniValue stop(const JSONRPCRequest& jsonRequest) +static RPCHelpMan stop() { static const std::string RESULT{PACKAGE_NAME " stopping"}; - // Accept the deprecated and ignored 'detach' boolean argument + return RPCHelpMan{"stop", // Also accept the hidden 'wait' integer argument (milliseconds) // For instance, 'stop 1000' makes the call wait 1 second before returning // to the client (intended for testing) - if (jsonRequest.fHelp || jsonRequest.params.size() > 1) - throw std::runtime_error( - RPCHelpMan{"stop", "\nRequest a graceful shutdown of " PACKAGE_NAME ".", - {}, + { + {"wait", RPCArg::Type::NUM, RPCArg::Optional::OMITTED_NAMED_ARG, "how long to wait in ms", "", {}, /* hidden */ true}, + }, RPCResult{RPCResult::Type::STR, "", "A string with the content '" + RESULT + "'"}, RPCExamples{""}, - }.ToString()); + [&](const RPCHelpMan& self, const JSONRPCRequest& jsonRequest) -> UniValue +{ // Event loop will exit after current HTTP requests have been handled, so // this reply will get back to the client. StartShutdown(); @@ -178,11 +174,13 @@ UniValue stop(const JSONRPCRequest& jsonRequest) UninterruptibleSleep(std::chrono::milliseconds{jsonRequest.params[0].get_int()}); } return RESULT; +}, + }; } -static UniValue uptime(const JSONRPCRequest& jsonRequest) +static RPCHelpMan uptime() { - RPCHelpMan{"uptime", + return RPCHelpMan{"uptime", "\nReturns the total uptime of the server.\n", {}, RPCResult{ @@ -192,14 +190,16 @@ static UniValue uptime(const JSONRPCRequest& jsonRequest) HelpExampleCli("uptime", "") + HelpExampleRpc("uptime", "") }, - }.Check(jsonRequest); - + [&](const RPCHelpMan& self, const JSONRPCRequest& request) -> UniValue +{ return GetTime() - GetStartupTime(); } + }; +} -static UniValue getrpcinfo(const JSONRPCRequest& request) +static RPCHelpMan getrpcinfo() { - RPCHelpMan{"getrpcinfo", + return RPCHelpMan{"getrpcinfo", "\nReturns details of the RPC server.\n", {}, RPCResult{ @@ -219,8 +219,8 @@ static UniValue getrpcinfo(const JSONRPCRequest& request) RPCExamples{ HelpExampleCli("getrpcinfo", "") + HelpExampleRpc("getrpcinfo", "")}, - }.Check(request); - + [&](const RPCHelpMan& self, const JSONRPCRequest& request) -> UniValue +{ LOCK(g_rpc_server_info.mutex); UniValue active_commands(UniValue::VARR); for (const RPCCommandExecutionInfo& info : g_rpc_server_info.active_commands) { @@ -239,6 +239,8 @@ static UniValue getrpcinfo(const JSONRPCRequest& request) return result; } + }; +} // clang-format off static const CRPCCommand vRPCCommands[] = @@ -295,6 +297,7 @@ void StartRPC() void InterruptRPC() { + static std::once_flag g_rpc_interrupt_flag; // This function could be called twice if the GUI has been started with -server=1. std::call_once(g_rpc_interrupt_flag, []() { LogPrint(BCLog::RPC, "Interrupting RPC\n"); @@ -305,6 +308,7 @@ void InterruptRPC() void StopRPC() { + static std::once_flag g_rpc_stop_flag; // This function could be called twice if the GUI has been started with -server=1. assert(!g_rpc_running); std::call_once(g_rpc_stop_flag, []() { @@ -327,20 +331,20 @@ void RpcInterruptionPoint() void SetRPCWarmupStatus(const std::string& newStatus) { - LOCK(cs_rpcWarmup); + LOCK(g_rpc_warmup_mutex); rpcWarmupStatus = newStatus; } void SetRPCWarmupFinished() { - LOCK(cs_rpcWarmup); + LOCK(g_rpc_warmup_mutex); assert(fRPCInWarmup); fRPCInWarmup = false; } bool RPCIsInWarmup(std::string *outStatus) { - LOCK(cs_rpcWarmup); + LOCK(g_rpc_warmup_mutex); if (outStatus) *outStatus = rpcWarmupStatus; return fRPCInWarmup; @@ -439,7 +443,7 @@ UniValue CRPCTable::execute(const JSONRPCRequest &request) const { // Return immediately if in warmup { - LOCK(cs_rpcWarmup); + LOCK(g_rpc_warmup_mutex); if (fRPCInWarmup) throw JSONRPCError(RPC_IN_WARMUP, rpcWarmupStatus); } diff --git a/src/rpc/server.h b/src/rpc/server.h index d7a04ff6e8..6da3e94ea2 100644 --- a/src/rpc/server.h +++ b/src/rpc/server.h @@ -8,6 +8,7 @@ #include <amount.h> #include <rpc/request.h> +#include <rpc/util.h> #include <functional> #include <map> @@ -85,6 +86,7 @@ void RPCUnsetTimerInterface(RPCTimerInterface *iface); void RPCRunLater(const std::string& name, std::function<void()> func, int64_t nSeconds); typedef UniValue(*rpcfn_type)(const JSONRPCRequest& jsonRequest); +typedef RPCHelpMan (*RpcMethodFnType)(); class CRPCCommand { @@ -101,6 +103,19 @@ public: { } + //! Simplified constructor taking plain RpcMethodFnType function pointer. + CRPCCommand(std::string category, std::string name_in, RpcMethodFnType fn, std::vector<std::string> args_in) + : CRPCCommand( + category, + fn().m_name, + [fn](const JSONRPCRequest& request, UniValue& result, bool) { result = fn().HandleRequest(request); return true; }, + fn().GetArgNames(), + intptr_t(fn)) + { + CHECK_NONFATAL(fn().m_name == name_in); + CHECK_NONFATAL(fn().GetArgNames() == args_in); + } + //! Simplified constructor taking plain rpcfn_type function pointer. CRPCCommand(const char* category, const char* name, rpcfn_type fn, std::initializer_list<const char*> args) : CRPCCommand(category, name, @@ -117,7 +132,7 @@ public: }; /** - * Bitcoin RPC command dispatcher. + * RPC command dispatcher. */ class CRPCTable { diff --git a/src/rpc/util.cpp b/src/rpc/util.cpp index 39bf05fbbd..073a7688a9 100644 --- a/src/rpc/util.cpp +++ b/src/rpc/util.cpp @@ -10,6 +10,7 @@ #include <tinyformat.h> #include <util/strencodings.h> #include <util/string.h> +#include <util/translation.h> #include <tuple> @@ -112,6 +113,23 @@ std::vector<unsigned char> ParseHexO(const UniValue& o, std::string strKey) return ParseHexV(find_value(o, strKey), strKey); } +CoinStatsHashType ParseHashType(const UniValue& param, const CoinStatsHashType default_type) +{ + if (param.isNull()) { + return default_type; + } else { + std::string hash_type_input = param.get_str(); + + if (hash_type_input == "hash_serialized_2") { + return CoinStatsHashType::HASH_SERIALIZED; + } else if (hash_type_input == "none") { + return CoinStatsHashType::NONE; + } else { + throw JSONRPCError(RPC_INVALID_PARAMETER, strprintf("%d is not a valid hash_type", hash_type_input)); + } + } +} + std::string HelpExampleCli(const std::string& methodname, const std::string& args) { return "> bitcoin-cli " + methodname + " " + args + "\n"; @@ -223,7 +241,7 @@ public: obj.pushKV("isscript", false); obj.pushKV("iswitness", true); obj.pushKV("witness_version", 0); - obj.pushKV("witness_program", HexStr(id.begin(), id.end())); + obj.pushKV("witness_program", HexStr(id)); return obj; } @@ -233,7 +251,7 @@ public: obj.pushKV("isscript", true); obj.pushKV("iswitness", true); obj.pushKV("witness_version", 0); - obj.pushKV("witness_program", HexStr(id.begin(), id.end())); + obj.pushKV("witness_program", HexStr(id)); return obj; } @@ -242,7 +260,7 @@ public: UniValue obj(UniValue::VOBJ); obj.pushKV("iswitness", true); obj.pushKV("witness_version", (int)id.version); - obj.pushKV("witness_program", HexStr(id.program, id.program + id.length)); + obj.pushKV("witness_program", HexStr(Span<const unsigned char>(id.program, id.length))); return obj; } }; @@ -285,7 +303,7 @@ UniValue JSONRPCTransactionError(TransactionError terr, const std::string& err_s if (err_string.length() > 0) { return JSONRPCError(RPCErrorFromTransactionError(terr), err_string); } else { - return JSONRPCError(RPCErrorFromTransactionError(terr), TransactionErrorString(terr)); + return JSONRPCError(RPCErrorFromTransactionError(terr), TransactionErrorString(terr).original); } } @@ -367,9 +385,7 @@ struct Sections { PushSection({indent + "]" + (outer_type != OuterType::NONE ? "," : ""), ""}); break; } - - // no default case, so the compiler can warn about missing cases - } + } // no default case, so the compiler can warn about missing cases } /** @@ -380,6 +396,9 @@ struct Sections { std::string ret; const size_t pad = m_max_pad + 4; for (const auto& s : m_sections) { + // The left part of a section is assumed to be a single line, usually it is the name of the JSON struct or a + // brace like {, }, [, or ] + CHECK_NONFATAL(s.m_left.find('\n') == std::string::npos); if (s.m_right.empty()) { ret += s.m_left; ret += "\n"; @@ -414,7 +433,11 @@ struct Sections { }; RPCHelpMan::RPCHelpMan(std::string name, std::string description, std::vector<RPCArg> args, RPCResults results, RPCExamples examples) + : RPCHelpMan{std::move(name), std::move(description), std::move(args), std::move(results), std::move(examples), nullptr} {} + +RPCHelpMan::RPCHelpMan(std::string name, std::string description, std::vector<RPCArg> args, RPCResults results, RPCExamples examples, RPCMethodImpl fun) : m_name{std::move(name)}, + m_fun{std::move(fun)}, m_description{std::move(description)}, m_args{std::move(args)}, m_results{std::move(results)}, @@ -463,6 +486,16 @@ bool RPCHelpMan::IsValidNumArgs(size_t num_args) const } return num_required_args <= num_args && num_args <= m_args.size(); } + +std::vector<std::string> RPCHelpMan::GetArgNames() const +{ + std::vector<std::string> ret; + for (const auto& arg : m_args) { + ret.emplace_back(arg.m_names); + } + return ret; +} + std::string RPCHelpMan::ToString() const { std::string ret; @@ -471,6 +504,7 @@ std::string RPCHelpMan::ToString() const ret += m_name; bool was_optional{false}; for (const auto& arg : m_args) { + if (arg.m_hidden) continue; const bool optional = arg.IsOptional(); ret += " "; if (optional) { @@ -492,6 +526,7 @@ std::string RPCHelpMan::ToString() const Sections sections; for (size_t i{0}; i < m_args.size(); ++i) { const auto& arg = m_args.at(i); + if (arg.m_hidden) continue; if (i == 0) ret += "\nArguments:\n"; @@ -571,9 +606,7 @@ std::string RPCArg::ToDescriptionString() const ret += "json array"; break; } - - // no default case, so the compiler can warn about missing cases - } + } // no default case, so the compiler can warn about missing cases } if (m_fallback.which() == 1) { ret += ", optional, default=" + boost::get<std::string>(m_fallback); @@ -591,9 +624,7 @@ std::string RPCArg::ToDescriptionString() const ret += ", required"; break; } - - // no default case, so the compiler can warn about missing cases - } + } // no default case, so the compiler can warn about missing cases } ret += ")"; ret += m_description.empty() ? "" : " " + m_description; @@ -688,10 +719,7 @@ void RPCResult::ToSections(Sections& sections, const OuterType outer_type, const sections.PushSection({indent + "}" + maybe_separator, ""}); return; } - - // no default case, so the compiler can warn about missing cases - } - + } // no default case, so the compiler can warn about missing cases CHECK_NONFATAL(false); } @@ -728,9 +756,7 @@ std::string RPCArg::ToStringObj(const bool oneline) const case Type::OBJ_USER_KEYS: // Currently unused, so avoid writing dead code CHECK_NONFATAL(false); - - // no default case, so the compiler can warn about missing cases - } + } // no default case, so the compiler can warn about missing cases CHECK_NONFATAL(false); } @@ -765,9 +791,7 @@ std::string RPCArg::ToString(const bool oneline) const } return "[" + res + "...]"; } - - // no default case, so the compiler can warn about missing cases - } + } // no default case, so the compiler can warn about missing cases CHECK_NONFATAL(false); } diff --git a/src/rpc/util.h b/src/rpc/util.h index 53dce2c397..45b0bb0c7e 100644 --- a/src/rpc/util.h +++ b/src/rpc/util.h @@ -5,6 +5,7 @@ #ifndef BITCOIN_RPC_UTIL_H #define BITCOIN_RPC_UTIL_H +#include <node/coinstats.h> #include <node/transaction.h> #include <outputtype.h> #include <protocol.h> @@ -77,6 +78,8 @@ extern uint256 ParseHashO(const UniValue& o, std::string strKey); extern std::vector<unsigned char> ParseHexV(const UniValue& v, std::string strName); extern std::vector<unsigned char> ParseHexO(const UniValue& o, std::string strKey); +CoinStatsHashType ParseHashType(const UniValue& param, const CoinStatsHashType default_type); + extern CAmount AmountFromValue(const UniValue& value); extern std::string HelpExampleCli(const std::string& methodname, const std::string& args); extern std::string HelpExampleRpc(const std::string& methodname, const std::string& args); @@ -144,6 +147,7 @@ struct RPCArg { using Fallback = boost::variant<Optional, /* default value for optional args */ std::string>; const std::string m_names; //!< The name of the arg (can be empty for inner args, can contain multiple aliases separated by | for named request arguments) const Type m_type; + const bool m_hidden; const std::vector<RPCArg> m_inner; //!< Only used for arrays or dicts const Fallback m_fallback; const std::string m_description; @@ -156,9 +160,11 @@ struct RPCArg { const Fallback fallback, const std::string description, const std::string oneline_description = "", - const std::vector<std::string> type_str = {}) + const std::vector<std::string> type_str = {}, + const bool hidden = false) : m_names{std::move(name)}, m_type{std::move(type)}, + m_hidden{hidden}, m_fallback{std::move(fallback)}, m_description{std::move(description)}, m_oneline_description{std::move(oneline_description)}, @@ -177,6 +183,7 @@ struct RPCArg { const std::vector<std::string> type_str = {}) : m_names{std::move(name)}, m_type{std::move(type)}, + m_hidden{false}, m_inner{std::move(inner)}, m_fallback{std::move(fallback)}, m_description{std::move(description)}, @@ -326,8 +333,15 @@ class RPCHelpMan { public: RPCHelpMan(std::string name, std::string description, std::vector<RPCArg> args, RPCResults results, RPCExamples examples); + using RPCMethodImpl = std::function<UniValue(const RPCHelpMan&, const JSONRPCRequest&)>; + RPCHelpMan(std::string name, std::string description, std::vector<RPCArg> args, RPCResults results, RPCExamples examples, RPCMethodImpl fun); std::string ToString() const; + UniValue HandleRequest(const JSONRPCRequest& request) + { + Check(request); + return m_fun(*this, request); + } /** If the supplied number of args is neither too small nor too high */ bool IsValidNumArgs(size_t num_args) const; /** @@ -340,8 +354,12 @@ public: } } -private: + std::vector<std::string> GetArgNames() const; + const std::string m_name; + +private: + const RPCMethodImpl m_fun; const std::string m_description; const std::vector<RPCArg> m_args; const RPCResults m_results; diff --git a/src/scheduler.cpp b/src/scheduler.cpp index c4bd47310b..7c361bf26f 100644 --- a/src/scheduler.cpp +++ b/src/scheduler.cpp @@ -30,9 +30,6 @@ void CScheduler::serviceQueue() // is called. while (!shouldStop()) { try { - if (!shouldStop() && taskQueue.empty()) { - REVERSE_LOCK(lock); - } while (!shouldStop() && taskQueue.empty()) { // Wait until there is something to do. newTaskScheduled.wait(lock); @@ -71,18 +68,6 @@ void CScheduler::serviceQueue() newTaskScheduled.notify_one(); } -void CScheduler::stop(bool drain) -{ - { - LOCK(newTaskMutex); - if (drain) - stopWhenEmpty = true; - else - stopRequested = true; - } - newTaskScheduled.notify_all(); -} - void CScheduler::schedule(CScheduler::Function f, std::chrono::system_clock::time_point t) { { @@ -125,8 +110,8 @@ void CScheduler::scheduleEvery(CScheduler::Function f, std::chrono::milliseconds scheduleFromNow([=] { Repeat(*this, f, delta); }, delta); } -size_t CScheduler::getQueueInfo(std::chrono::system_clock::time_point &first, - std::chrono::system_clock::time_point &last) const +size_t CScheduler::getQueueInfo(std::chrono::system_clock::time_point& first, + std::chrono::system_clock::time_point& last) const { LOCK(newTaskMutex); size_t result = taskQueue.size(); @@ -137,13 +122,15 @@ size_t CScheduler::getQueueInfo(std::chrono::system_clock::time_point &first, return result; } -bool CScheduler::AreThreadsServicingQueue() const { +bool CScheduler::AreThreadsServicingQueue() const +{ LOCK(newTaskMutex); return nThreadsServicingQueue; } -void SingleThreadedSchedulerClient::MaybeScheduleProcessQueue() { +void SingleThreadedSchedulerClient::MaybeScheduleProcessQueue() +{ { LOCK(m_cs_callbacks_pending); // Try to avoid scheduling too many copies here, but if we @@ -155,8 +142,9 @@ void SingleThreadedSchedulerClient::MaybeScheduleProcessQueue() { m_pscheduler->schedule(std::bind(&SingleThreadedSchedulerClient::ProcessQueue, this), std::chrono::system_clock::now()); } -void SingleThreadedSchedulerClient::ProcessQueue() { - std::function<void ()> callback; +void SingleThreadedSchedulerClient::ProcessQueue() +{ + std::function<void()> callback; { LOCK(m_cs_callbacks_pending); if (m_are_callbacks_running) return; @@ -172,7 +160,8 @@ void SingleThreadedSchedulerClient::ProcessQueue() { struct RAIICallbacksRunning { SingleThreadedSchedulerClient* instance; explicit RAIICallbacksRunning(SingleThreadedSchedulerClient* _instance) : instance(_instance) {} - ~RAIICallbacksRunning() { + ~RAIICallbacksRunning() + { { LOCK(instance->m_cs_callbacks_pending); instance->m_are_callbacks_running = false; @@ -184,7 +173,8 @@ void SingleThreadedSchedulerClient::ProcessQueue() { callback(); } -void SingleThreadedSchedulerClient::AddToProcessQueue(std::function<void ()> func) { +void SingleThreadedSchedulerClient::AddToProcessQueue(std::function<void()> func) +{ assert(m_pscheduler); { @@ -194,7 +184,8 @@ void SingleThreadedSchedulerClient::AddToProcessQueue(std::function<void ()> fun MaybeScheduleProcessQueue(); } -void SingleThreadedSchedulerClient::EmptyQueue() { +void SingleThreadedSchedulerClient::EmptyQueue() +{ assert(!m_pscheduler->AreThreadsServicingQueue()); bool should_continue = true; while (should_continue) { @@ -204,7 +195,8 @@ void SingleThreadedSchedulerClient::EmptyQueue() { } } -size_t SingleThreadedSchedulerClient::CallbacksPending() { +size_t SingleThreadedSchedulerClient::CallbacksPending() +{ LOCK(m_cs_callbacks_pending); return m_callbacks_pending.size(); } diff --git a/src/scheduler.h b/src/scheduler.h index 1e64195484..d7fe00d1b4 100644 --- a/src/scheduler.h +++ b/src/scheduler.h @@ -5,11 +5,6 @@ #ifndef BITCOIN_SCHEDULER_H #define BITCOIN_SCHEDULER_H -// -// NOTE: -// boost::thread should be ported to std::thread -// when we support C++11. -// #include <condition_variable> #include <functional> #include <list> @@ -17,24 +12,23 @@ #include <sync.h> -// -// Simple class for background tasks that should be run -// periodically or once "after a while" -// -// Usage: -// -// CScheduler* s = new CScheduler(); -// s->scheduleFromNow(doSomething, std::chrono::milliseconds{11}); // Assuming a: void doSomething() { } -// s->scheduleFromNow([=] { this->func(argument); }, std::chrono::milliseconds{3}); -// boost::thread* t = new boost::thread(std::bind(CScheduler::serviceQueue, s)); -// -// ... then at program shutdown, make sure to call stop() to clean up the thread(s) running serviceQueue: -// s->stop(); -// t->join(); -// delete t; -// delete s; // Must be done after thread is interrupted/joined. -// - +/** + * Simple class for background tasks that should be run + * periodically or once "after a while" + * + * Usage: + * + * CScheduler* s = new CScheduler(); + * s->scheduleFromNow(doSomething, std::chrono::milliseconds{11}); // Assuming a: void doSomething() { } + * s->scheduleFromNow([=] { this->func(argument); }, std::chrono::milliseconds{3}); + * std::thread* t = new std::thread([&] { s->serviceQueue(); }); + * + * ... then at program shutdown, make sure to call stop() to clean up the thread(s) running serviceQueue: + * s->stop(); + * t->join(); + * delete t; + * delete s; // Must be done after thread is interrupted/joined. + */ class CScheduler { public: @@ -43,7 +37,7 @@ public: typedef std::function<void()> Function; - // Call func at/after time t + /** Call func at/after time t */ void schedule(Function f, std::chrono::system_clock::time_point t); /** Call f once after the delta has passed */ @@ -67,23 +61,33 @@ public: */ void MockForward(std::chrono::seconds delta_seconds); - // To keep things as simple as possible, there is no unschedule. - - // Services the queue 'forever'. Should be run in a thread, - // and interrupted using boost::interrupt_thread + /** + * Services the queue 'forever'. Should be run in a thread, + * and interrupted using boost::interrupt_thread + */ void serviceQueue(); - // Tell any threads running serviceQueue to stop as soon as they're - // done servicing whatever task they're currently servicing (drain=false) - // or when there is no work left to be done (drain=true) - void stop(bool drain=false); + /** Tell any threads running serviceQueue to stop as soon as the current task is done */ + void stop() + { + WITH_LOCK(newTaskMutex, stopRequested = true); + newTaskScheduled.notify_all(); + } + /** Tell any threads running serviceQueue to stop when there is no work left to be done */ + void StopWhenDrained() + { + WITH_LOCK(newTaskMutex, stopWhenEmpty = true); + newTaskScheduled.notify_all(); + } - // Returns number of tasks waiting to be serviced, - // and first and last task times - size_t getQueueInfo(std::chrono::system_clock::time_point &first, - std::chrono::system_clock::time_point &last) const; + /** + * Returns number of tasks waiting to be serviced, + * and first and last task times + */ + size_t getQueueInfo(std::chrono::system_clock::time_point& first, + std::chrono::system_clock::time_point& last) const; - // Returns true if there are threads actively running in serviceQueue() + /** Returns true if there are threads actively running in serviceQueue() */ bool AreThreadsServicingQueue() const; private: @@ -106,19 +110,20 @@ private: * B() will be able to observe all of the effects of callback A() which executed * before it. */ -class SingleThreadedSchedulerClient { +class SingleThreadedSchedulerClient +{ private: - CScheduler *m_pscheduler; + CScheduler* m_pscheduler; RecursiveMutex m_cs_callbacks_pending; - std::list<std::function<void ()>> m_callbacks_pending GUARDED_BY(m_cs_callbacks_pending); + std::list<std::function<void()>> m_callbacks_pending GUARDED_BY(m_cs_callbacks_pending); bool m_are_callbacks_running GUARDED_BY(m_cs_callbacks_pending) = false; void MaybeScheduleProcessQueue(); void ProcessQueue(); public: - explicit SingleThreadedSchedulerClient(CScheduler *pschedulerIn) : m_pscheduler(pschedulerIn) {} + explicit SingleThreadedSchedulerClient(CScheduler* pschedulerIn) : m_pscheduler(pschedulerIn) {} /** * Add a callback to be executed. Callbacks are executed serially @@ -126,10 +131,12 @@ public: * Practically, this means that callbacks can behave as if they are executed * in order by a single thread. */ - void AddToProcessQueue(std::function<void ()> func); + void AddToProcessQueue(std::function<void()> func); - // Processes all remaining queue members on the calling thread, blocking until queue is empty - // Must be called after the CScheduler has no remaining processing threads! + /** + * Processes all remaining queue members on the calling thread, blocking until queue is empty + * Must be called after the CScheduler has no remaining processing threads! + */ void EmptyQueue(); size_t CallbacksPending(); diff --git a/src/script/descriptor.cpp b/src/script/descriptor.cpp index ed0175bb10..6c0a98cca2 100644 --- a/src/script/descriptor.cpp +++ b/src/script/descriptor.cpp @@ -139,7 +139,7 @@ std::string DescriptorChecksum(const Span<const char>& span) return ret; } -std::string AddChecksum(const std::string& str) { return str + "#" + DescriptorChecksum(MakeSpan(str)); } +std::string AddChecksum(const std::string& str) { return str + "#" + DescriptorChecksum(str); } //////////////////////////////////////////////////////////////////////////// // Internal representation // @@ -190,7 +190,7 @@ class OriginPubkeyProvider final : public PubkeyProvider std::string OriginString() const { - return HexStr(std::begin(m_origin.fingerprint), std::end(m_origin.fingerprint)) + FormatHDKeypath(m_origin.path); + return HexStr(m_origin.fingerprint) + FormatHDKeypath(m_origin.path); } public: @@ -235,7 +235,7 @@ public: } bool IsRange() const override { return false; } size_t GetSize() const override { return m_pubkey.size(); } - std::string ToString() const override { return HexStr(m_pubkey.begin(), m_pubkey.end()); } + std::string ToString() const override { return HexStr(m_pubkey); } bool ToPrivateString(const SigningProvider& arg, std::string& ret) const override { CKey key; @@ -583,7 +583,7 @@ class RawDescriptor final : public DescriptorImpl { const CScript m_script; protected: - std::string ToStringExtra() const override { return HexStr(m_script.begin(), m_script.end()); } + std::string ToStringExtra() const override { return HexStr(m_script); } std::vector<CScript> MakeScripts(const std::vector<CPubKey>&, const CScript*, FlatSigningProvider&) const override { return Vector(m_script); } public: RawDescriptor(CScript script) : DescriptorImpl({}, {}, "raw"), m_script(std::move(script)) {} @@ -825,8 +825,9 @@ std::unique_ptr<PubkeyProvider> ParsePubkey(uint32_t key_exp_index, const Span<c return nullptr; } if (origin_split.size() == 1) return ParsePubkeyInner(key_exp_index, origin_split[0], permit_uncompressed, out, error); - if (origin_split[0].size() < 1 || origin_split[0][0] != '[') { - error = strprintf("Key origin start '[ character expected but not found, got '%c' instead", origin_split[0][0]); + if (origin_split[0].empty() || origin_split[0][0] != '[') { + error = strprintf("Key origin start '[ character expected but not found, got '%c' instead", + origin_split[0].empty() ? /** empty, implies split char */ ']' : origin_split[0][0]); return nullptr; } auto slash_split = Split(origin_split[0].subspan(1), '/'); @@ -896,7 +897,7 @@ std::unique_ptr<DescriptorImpl> ParseScript(uint32_t key_exp_index, Span<const c providers.emplace_back(std::move(pk)); key_exp_index++; } - if (providers.size() < 1 || providers.size() > 16) { + if (providers.empty() || providers.size() > 16) { error = strprintf("Cannot have %u keys in multisig; must have between 1 and 16 keys, inclusive", providers.size()); return nullptr; } else if (thres < 1) { @@ -985,15 +986,15 @@ std::unique_ptr<PubkeyProvider> InferPubkey(const CPubKey& pubkey, ParseScriptCo std::unique_ptr<DescriptorImpl> InferScript(const CScript& script, ParseScriptContext ctx, const SigningProvider& provider) { std::vector<std::vector<unsigned char>> data; - txnouttype txntype = Solver(script, data); + TxoutType txntype = Solver(script, data); - if (txntype == TX_PUBKEY) { + if (txntype == TxoutType::PUBKEY) { CPubKey pubkey(data[0].begin(), data[0].end()); if (pubkey.IsValid()) { return MakeUnique<PKDescriptor>(InferPubkey(pubkey, ctx, provider)); } } - if (txntype == TX_PUBKEYHASH) { + if (txntype == TxoutType::PUBKEYHASH) { uint160 hash(data[0]); CKeyID keyid(hash); CPubKey pubkey; @@ -1001,7 +1002,7 @@ std::unique_ptr<DescriptorImpl> InferScript(const CScript& script, ParseScriptCo return MakeUnique<PKHDescriptor>(InferPubkey(pubkey, ctx, provider)); } } - if (txntype == TX_WITNESS_V0_KEYHASH && ctx != ParseScriptContext::P2WSH) { + if (txntype == TxoutType::WITNESS_V0_KEYHASH && ctx != ParseScriptContext::P2WSH) { uint160 hash(data[0]); CKeyID keyid(hash); CPubKey pubkey; @@ -1009,7 +1010,7 @@ std::unique_ptr<DescriptorImpl> InferScript(const CScript& script, ParseScriptCo return MakeUnique<WPKHDescriptor>(InferPubkey(pubkey, ctx, provider)); } } - if (txntype == TX_MULTISIG) { + if (txntype == TxoutType::MULTISIG) { std::vector<std::unique_ptr<PubkeyProvider>> providers; for (size_t i = 1; i + 1 < data.size(); ++i) { CPubKey pubkey(data[i].begin(), data[i].end()); @@ -1017,7 +1018,7 @@ std::unique_ptr<DescriptorImpl> InferScript(const CScript& script, ParseScriptCo } return MakeUnique<MultisigDescriptor>((int)data[0][0], std::move(providers)); } - if (txntype == TX_SCRIPTHASH && ctx == ParseScriptContext::TOP) { + if (txntype == TxoutType::SCRIPTHASH && ctx == ParseScriptContext::TOP) { uint160 hash(data[0]); CScriptID scriptid(hash); CScript subscript; @@ -1026,7 +1027,7 @@ std::unique_ptr<DescriptorImpl> InferScript(const CScript& script, ParseScriptCo if (sub) return MakeUnique<SHDescriptor>(std::move(sub)); } } - if (txntype == TX_WITNESS_V0_SCRIPTHASH && ctx != ParseScriptContext::P2WSH) { + if (txntype == TxoutType::WITNESS_V0_SCRIPTHASH && ctx != ParseScriptContext::P2WSH) { CScriptID scriptid; CRIPEMD160().Write(data[0].data(), data[0].size()).Finalize(scriptid.begin()); CScript subscript; @@ -1087,7 +1088,7 @@ bool CheckChecksum(Span<const char>& sp, bool require_checksum, std::string& err std::unique_ptr<Descriptor> Parse(const std::string& descriptor, FlatSigningProvider& out, std::string& error, bool require_checksum) { - Span<const char> sp(descriptor.data(), descriptor.size()); + Span<const char> sp{descriptor}; if (!CheckChecksum(sp, require_checksum, error)) return nullptr; auto ret = ParseScript(0, sp, ParseScriptContext::TOP, out, error); if (sp.size() == 0 && ret) return std::unique_ptr<Descriptor>(std::move(ret)); @@ -1098,7 +1099,7 @@ std::string GetDescriptorChecksum(const std::string& descriptor) { std::string ret; std::string error; - Span<const char> sp(descriptor.data(), descriptor.size()); + Span<const char> sp{descriptor}; if (!CheckChecksum(sp, false, error, &ret)) return ""; return ret; } diff --git a/src/script/interpreter.cpp b/src/script/interpreter.cpp index 23d5b72a5c..39feb4ccc9 100644 --- a/src/script/interpreter.cpp +++ b/src/script/interpreter.cpp @@ -986,9 +986,9 @@ bool EvalScript(std::vector<std::vector<unsigned char> >& stack, const CScript& else if (opcode == OP_SHA256) CSHA256().Write(vch.data(), vch.size()).Finalize(vchHash.data()); else if (opcode == OP_HASH160) - CHash160().Write(vch.data(), vch.size()).Finalize(vchHash.data()); + CHash160().Write(vch).Finalize(vchHash); else if (opcode == OP_HASH256) - CHash256().Write(vch.data(), vch.size()).Finalize(vchHash.data()); + CHash256().Write(vch).Finalize(vchHash); popstack(stack); stack.push_back(vchHash); } @@ -1522,7 +1522,7 @@ static bool ExecuteWitnessScript(const Span<const valtype>& stack_span, const CS static bool VerifyWitnessProgram(const CScriptWitness& witness, int witversion, const std::vector<unsigned char>& program, unsigned int flags, const BaseSignatureChecker& checker, ScriptError* serror) { CScript scriptPubKey; - Span<const valtype> stack = MakeSpan(witness.stack); + Span<const valtype> stack{witness.stack}; if (witversion == 0) { if (program.size() == WITNESS_V0_SCRIPTHASH_SIZE) { diff --git a/src/script/sigcache.cpp b/src/script/sigcache.cpp index e7b6df3ce8..aaecab1ef2 100644 --- a/src/script/sigcache.cpp +++ b/src/script/sigcache.cpp @@ -11,7 +11,7 @@ #include <util/system.h> #include <cuckoocache.h> -#include <boost/thread.hpp> +#include <boost/thread/shared_mutex.hpp> namespace { /** @@ -23,7 +23,7 @@ class CSignatureCache { private: //! Entries are SHA256(nonce || signature hash || public key || signature): - uint256 nonce; + CSHA256 m_salted_hasher; typedef CuckooCache::cache<uint256, SignatureCacheHasher> map_type; map_type setValid; boost::shared_mutex cs_sigcache; @@ -31,13 +31,19 @@ private: public: CSignatureCache() { - GetRandBytes(nonce.begin(), 32); + uint256 nonce = GetRandHash(); + // We want the nonce to be 64 bytes long to force the hasher to process + // this chunk, which makes later hash computations more efficient. We + // just write our 32-byte entropy twice to fill the 64 bytes. + m_salted_hasher.Write(nonce.begin(), 32); + m_salted_hasher.Write(nonce.begin(), 32); } void ComputeEntry(uint256& entry, const uint256 &hash, const std::vector<unsigned char>& vchSig, const CPubKey& pubkey) { - CSHA256().Write(nonce.begin(), 32).Write(hash.begin(), 32).Write(&pubkey[0], pubkey.size()).Write(&vchSig[0], vchSig.size()).Finalize(entry.begin()); + CSHA256 hasher = m_salted_hasher; + hasher.Write(hash.begin(), 32).Write(&pubkey[0], pubkey.size()).Write(&vchSig[0], vchSig.size()).Finalize(entry.begin()); } bool diff --git a/src/script/sign.cpp b/src/script/sign.cpp index 1e00afcf89..f425215549 100644 --- a/src/script/sign.cpp +++ b/src/script/sign.cpp @@ -92,11 +92,11 @@ static bool CreateSig(const BaseSignatureCreator& creator, SignatureData& sigdat /** * Sign scriptPubKey using signature made with creator. * Signatures are returned in scriptSigRet (or returns false if scriptPubKey can't be signed), - * unless whichTypeRet is TX_SCRIPTHASH, in which case scriptSigRet is the redemption script. + * unless whichTypeRet is TxoutType::SCRIPTHASH, in which case scriptSigRet is the redemption script. * Returns false if scriptPubKey could not be completely satisfied. */ static bool SignStep(const SigningProvider& provider, const BaseSignatureCreator& creator, const CScript& scriptPubKey, - std::vector<valtype>& ret, txnouttype& whichTypeRet, SigVersion sigversion, SignatureData& sigdata) + std::vector<valtype>& ret, TxoutType& whichTypeRet, SigVersion sigversion, SignatureData& sigdata) { CScript scriptRet; uint160 h160; @@ -108,15 +108,15 @@ static bool SignStep(const SigningProvider& provider, const BaseSignatureCreator switch (whichTypeRet) { - case TX_NONSTANDARD: - case TX_NULL_DATA: - case TX_WITNESS_UNKNOWN: + case TxoutType::NONSTANDARD: + case TxoutType::NULL_DATA: + case TxoutType::WITNESS_UNKNOWN: return false; - case TX_PUBKEY: + case TxoutType::PUBKEY: if (!CreateSig(creator, sigdata, provider, sig, CPubKey(vSolutions[0]), scriptPubKey, sigversion)) return false; ret.push_back(std::move(sig)); return true; - case TX_PUBKEYHASH: { + case TxoutType::PUBKEYHASH: { CKeyID keyID = CKeyID(uint160(vSolutions[0])); CPubKey pubkey; if (!GetPubKey(provider, sigdata, keyID, pubkey)) { @@ -129,9 +129,9 @@ static bool SignStep(const SigningProvider& provider, const BaseSignatureCreator ret.push_back(ToByteVector(pubkey)); return true; } - case TX_SCRIPTHASH: + case TxoutType::SCRIPTHASH: h160 = uint160(vSolutions[0]); - if (GetCScript(provider, sigdata, h160, scriptRet)) { + if (GetCScript(provider, sigdata, CScriptID{h160}, scriptRet)) { ret.push_back(std::vector<unsigned char>(scriptRet.begin(), scriptRet.end())); return true; } @@ -139,7 +139,7 @@ static bool SignStep(const SigningProvider& provider, const BaseSignatureCreator sigdata.missing_redeem_script = h160; return false; - case TX_MULTISIG: { + case TxoutType::MULTISIG: { size_t required = vSolutions.front()[0]; ret.push_back(valtype()); // workaround CHECKMULTISIG bug for (size_t i = 1; i < vSolutions.size() - 1; ++i) { @@ -159,13 +159,13 @@ static bool SignStep(const SigningProvider& provider, const BaseSignatureCreator } return ok; } - case TX_WITNESS_V0_KEYHASH: + case TxoutType::WITNESS_V0_KEYHASH: ret.push_back(vSolutions[0]); return true; - case TX_WITNESS_V0_SCRIPTHASH: + case TxoutType::WITNESS_V0_SCRIPTHASH: CRIPEMD160().Write(&vSolutions[0][0], vSolutions[0].size()).Finalize(h160.begin()); - if (GetCScript(provider, sigdata, h160, scriptRet)) { + if (GetCScript(provider, sigdata, CScriptID{h160}, scriptRet)) { ret.push_back(std::vector<unsigned char>(scriptRet.begin(), scriptRet.end())); return true; } @@ -198,44 +198,44 @@ bool ProduceSignature(const SigningProvider& provider, const BaseSignatureCreato if (sigdata.complete) return true; std::vector<valtype> result; - txnouttype whichType; + TxoutType whichType; bool solved = SignStep(provider, creator, fromPubKey, result, whichType, SigVersion::BASE, sigdata); bool P2SH = false; CScript subscript; sigdata.scriptWitness.stack.clear(); - if (solved && whichType == TX_SCRIPTHASH) + if (solved && whichType == TxoutType::SCRIPTHASH) { // Solver returns the subscript that needs to be evaluated; // the final scriptSig is the signatures from that // and then the serialized subscript: subscript = CScript(result[0].begin(), result[0].end()); sigdata.redeem_script = subscript; - solved = solved && SignStep(provider, creator, subscript, result, whichType, SigVersion::BASE, sigdata) && whichType != TX_SCRIPTHASH; + solved = solved && SignStep(provider, creator, subscript, result, whichType, SigVersion::BASE, sigdata) && whichType != TxoutType::SCRIPTHASH; P2SH = true; } - if (solved && whichType == TX_WITNESS_V0_KEYHASH) + if (solved && whichType == TxoutType::WITNESS_V0_KEYHASH) { CScript witnessscript; witnessscript << OP_DUP << OP_HASH160 << ToByteVector(result[0]) << OP_EQUALVERIFY << OP_CHECKSIG; - txnouttype subType; + TxoutType subType; solved = solved && SignStep(provider, creator, witnessscript, result, subType, SigVersion::WITNESS_V0, sigdata); sigdata.scriptWitness.stack = result; sigdata.witness = true; result.clear(); } - else if (solved && whichType == TX_WITNESS_V0_SCRIPTHASH) + else if (solved && whichType == TxoutType::WITNESS_V0_SCRIPTHASH) { CScript witnessscript(result[0].begin(), result[0].end()); sigdata.witness_script = witnessscript; - txnouttype subType; - solved = solved && SignStep(provider, creator, witnessscript, result, subType, SigVersion::WITNESS_V0, sigdata) && subType != TX_SCRIPTHASH && subType != TX_WITNESS_V0_SCRIPTHASH && subType != TX_WITNESS_V0_KEYHASH; + TxoutType subType; + solved = solved && SignStep(provider, creator, witnessscript, result, subType, SigVersion::WITNESS_V0, sigdata) && subType != TxoutType::SCRIPTHASH && subType != TxoutType::WITNESS_V0_SCRIPTHASH && subType != TxoutType::WITNESS_V0_KEYHASH; result.push_back(std::vector<unsigned char>(witnessscript.begin(), witnessscript.end())); sigdata.scriptWitness.stack = result; sigdata.witness = true; result.clear(); - } else if (solved && whichType == TX_WITNESS_UNKNOWN) { + } else if (solved && whichType == TxoutType::WITNESS_UNKNOWN) { sigdata.witness = true; } @@ -301,11 +301,11 @@ SignatureData DataFromTransaction(const CMutableTransaction& tx, unsigned int nI // Get scripts std::vector<std::vector<unsigned char>> solutions; - txnouttype script_type = Solver(txout.scriptPubKey, solutions); + TxoutType script_type = Solver(txout.scriptPubKey, solutions); SigVersion sigversion = SigVersion::BASE; CScript next_script = txout.scriptPubKey; - if (script_type == TX_SCRIPTHASH && !stack.script.empty() && !stack.script.back().empty()) { + if (script_type == TxoutType::SCRIPTHASH && !stack.script.empty() && !stack.script.back().empty()) { // Get the redeemScript CScript redeem_script(stack.script.back().begin(), stack.script.back().end()); data.redeem_script = redeem_script; @@ -315,7 +315,7 @@ SignatureData DataFromTransaction(const CMutableTransaction& tx, unsigned int nI script_type = Solver(next_script, solutions); stack.script.pop_back(); } - if (script_type == TX_WITNESS_V0_SCRIPTHASH && !stack.witness.empty() && !stack.witness.back().empty()) { + if (script_type == TxoutType::WITNESS_V0_SCRIPTHASH && !stack.witness.empty() && !stack.witness.back().empty()) { // Get the witnessScript CScript witness_script(stack.witness.back().begin(), stack.witness.back().end()); data.witness_script = witness_script; @@ -328,7 +328,7 @@ SignatureData DataFromTransaction(const CMutableTransaction& tx, unsigned int nI stack.witness.clear(); sigversion = SigVersion::WITNESS_V0; } - if (script_type == TX_MULTISIG && !stack.script.empty()) { + if (script_type == TxoutType::MULTISIG && !stack.script.empty()) { // Build a map of pubkey -> signature by matching sigs to pubkeys: assert(solutions.size() > 1); unsigned int num_pubkeys = solutions.size()-2; @@ -454,13 +454,13 @@ bool IsSegWitOutput(const SigningProvider& provider, const CScript& script) { std::vector<valtype> solutions; auto whichtype = Solver(script, solutions); - if (whichtype == TX_WITNESS_V0_SCRIPTHASH || whichtype == TX_WITNESS_V0_KEYHASH || whichtype == TX_WITNESS_UNKNOWN) return true; - if (whichtype == TX_SCRIPTHASH) { + if (whichtype == TxoutType::WITNESS_V0_SCRIPTHASH || whichtype == TxoutType::WITNESS_V0_KEYHASH || whichtype == TxoutType::WITNESS_UNKNOWN) return true; + if (whichtype == TxoutType::SCRIPTHASH) { auto h160 = uint160(solutions[0]); CScript subscript; - if (provider.GetCScript(h160, subscript)) { + if (provider.GetCScript(CScriptID{h160}, subscript)) { whichtype = Solver(subscript, solutions); - if (whichtype == TX_WITNESS_V0_SCRIPTHASH || whichtype == TX_WITNESS_V0_KEYHASH || whichtype == TX_WITNESS_UNKNOWN) return true; + if (whichtype == TxoutType::WITNESS_V0_SCRIPTHASH || whichtype == TxoutType::WITNESS_V0_KEYHASH || whichtype == TxoutType::WITNESS_UNKNOWN) return true; } } return false; diff --git a/src/script/signingprovider.cpp b/src/script/signingprovider.cpp index 01757e2f65..2d8dc7d471 100644 --- a/src/script/signingprovider.cpp +++ b/src/script/signingprovider.cpp @@ -180,10 +180,10 @@ CKeyID GetKeyForDestination(const SigningProvider& store, const CTxDestination& // Only supports destinations which map to single public keys, i.e. P2PKH, // P2WPKH, and P2SH-P2WPKH. if (auto id = boost::get<PKHash>(&dest)) { - return CKeyID(*id); + return ToKeyID(*id); } if (auto witness_id = boost::get<WitnessV0KeyHash>(&dest)) { - return CKeyID(*witness_id); + return ToKeyID(*witness_id); } if (auto script_hash = boost::get<ScriptHash>(&dest)) { CScript script; @@ -191,7 +191,7 @@ CKeyID GetKeyForDestination(const SigningProvider& store, const CTxDestination& CTxDestination inner_dest; if (store.GetCScript(script_id, script) && ExtractDestination(script, inner_dest)) { if (auto inner_witness_id = boost::get<WitnessV0KeyHash>(&inner_dest)) { - return CKeyID(*inner_witness_id); + return ToKeyID(*inner_witness_id); } } } diff --git a/src/script/standard.cpp b/src/script/standard.cpp index c90c2c24a0..3a4882f280 100644 --- a/src/script/standard.cpp +++ b/src/script/standard.cpp @@ -16,31 +16,47 @@ typedef std::vector<unsigned char> valtype; bool fAcceptDatacarrier = DEFAULT_ACCEPT_DATACARRIER; unsigned nMaxDatacarrierBytes = MAX_OP_RETURN_RELAY; -CScriptID::CScriptID(const CScript& in) : uint160(Hash160(in.begin(), in.end())) {} +CScriptID::CScriptID(const CScript& in) : BaseHash(Hash160(in)) {} +CScriptID::CScriptID(const ScriptHash& in) : BaseHash(static_cast<uint160>(in)) {} -ScriptHash::ScriptHash(const CScript& in) : uint160(Hash160(in.begin(), in.end())) {} +ScriptHash::ScriptHash(const CScript& in) : BaseHash(Hash160(in)) {} +ScriptHash::ScriptHash(const CScriptID& in) : BaseHash(static_cast<uint160>(in)) {} -PKHash::PKHash(const CPubKey& pubkey) : uint160(pubkey.GetID()) {} +PKHash::PKHash(const CPubKey& pubkey) : BaseHash(pubkey.GetID()) {} +PKHash::PKHash(const CKeyID& pubkey_id) : BaseHash(pubkey_id) {} + +WitnessV0KeyHash::WitnessV0KeyHash(const CPubKey& pubkey) : BaseHash(pubkey.GetID()) {} +WitnessV0KeyHash::WitnessV0KeyHash(const PKHash& pubkey_hash) : BaseHash(static_cast<uint160>(pubkey_hash)) {} + +CKeyID ToKeyID(const PKHash& key_hash) +{ + return CKeyID{static_cast<uint160>(key_hash)}; +} + +CKeyID ToKeyID(const WitnessV0KeyHash& key_hash) +{ + return CKeyID{static_cast<uint160>(key_hash)}; +} WitnessV0ScriptHash::WitnessV0ScriptHash(const CScript& in) { CSHA256().Write(in.data(), in.size()).Finalize(begin()); } -std::string GetTxnOutputType(txnouttype t) +std::string GetTxnOutputType(TxoutType t) { switch (t) { - case TX_NONSTANDARD: return "nonstandard"; - case TX_PUBKEY: return "pubkey"; - case TX_PUBKEYHASH: return "pubkeyhash"; - case TX_SCRIPTHASH: return "scripthash"; - case TX_MULTISIG: return "multisig"; - case TX_NULL_DATA: return "nulldata"; - case TX_WITNESS_V0_KEYHASH: return "witness_v0_keyhash"; - case TX_WITNESS_V0_SCRIPTHASH: return "witness_v0_scripthash"; - case TX_WITNESS_UNKNOWN: return "witness_unknown"; - } + case TxoutType::NONSTANDARD: return "nonstandard"; + case TxoutType::PUBKEY: return "pubkey"; + case TxoutType::PUBKEYHASH: return "pubkeyhash"; + case TxoutType::SCRIPTHASH: return "scripthash"; + case TxoutType::MULTISIG: return "multisig"; + case TxoutType::NULL_DATA: return "nulldata"; + case TxoutType::WITNESS_V0_KEYHASH: return "witness_v0_keyhash"; + case TxoutType::WITNESS_V0_SCRIPTHASH: return "witness_v0_scripthash"; + case TxoutType::WITNESS_UNKNOWN: return "witness_unknown"; + } // no default case, so the compiler can warn about missing cases assert(false); } @@ -90,7 +106,7 @@ static bool MatchMultisig(const CScript& script, unsigned int& required, std::ve return (it + 1 == script.end()); } -txnouttype Solver(const CScript& scriptPubKey, std::vector<std::vector<unsigned char>>& vSolutionsRet) +TxoutType Solver(const CScript& scriptPubKey, std::vector<std::vector<unsigned char>>& vSolutionsRet) { vSolutionsRet.clear(); @@ -100,7 +116,7 @@ txnouttype Solver(const CScript& scriptPubKey, std::vector<std::vector<unsigned { std::vector<unsigned char> hashBytes(scriptPubKey.begin()+2, scriptPubKey.begin()+22); vSolutionsRet.push_back(hashBytes); - return TX_SCRIPTHASH; + return TxoutType::SCRIPTHASH; } int witnessversion; @@ -108,18 +124,18 @@ txnouttype Solver(const CScript& scriptPubKey, std::vector<std::vector<unsigned if (scriptPubKey.IsWitnessProgram(witnessversion, witnessprogram)) { if (witnessversion == 0 && witnessprogram.size() == WITNESS_V0_KEYHASH_SIZE) { vSolutionsRet.push_back(witnessprogram); - return TX_WITNESS_V0_KEYHASH; + return TxoutType::WITNESS_V0_KEYHASH; } if (witnessversion == 0 && witnessprogram.size() == WITNESS_V0_SCRIPTHASH_SIZE) { vSolutionsRet.push_back(witnessprogram); - return TX_WITNESS_V0_SCRIPTHASH; + return TxoutType::WITNESS_V0_SCRIPTHASH; } if (witnessversion != 0) { vSolutionsRet.push_back(std::vector<unsigned char>{(unsigned char)witnessversion}); vSolutionsRet.push_back(std::move(witnessprogram)); - return TX_WITNESS_UNKNOWN; + return TxoutType::WITNESS_UNKNOWN; } - return TX_NONSTANDARD; + return TxoutType::NONSTANDARD; } // Provably prunable, data-carrying output @@ -128,18 +144,18 @@ txnouttype Solver(const CScript& scriptPubKey, std::vector<std::vector<unsigned // byte passes the IsPushOnly() test we don't care what exactly is in the // script. if (scriptPubKey.size() >= 1 && scriptPubKey[0] == OP_RETURN && scriptPubKey.IsPushOnly(scriptPubKey.begin()+1)) { - return TX_NULL_DATA; + return TxoutType::NULL_DATA; } std::vector<unsigned char> data; if (MatchPayToPubkey(scriptPubKey, data)) { vSolutionsRet.push_back(std::move(data)); - return TX_PUBKEY; + return TxoutType::PUBKEY; } if (MatchPayToPubkeyHash(scriptPubKey, data)) { vSolutionsRet.push_back(std::move(data)); - return TX_PUBKEYHASH; + return TxoutType::PUBKEYHASH; } unsigned int required; @@ -148,19 +164,19 @@ txnouttype Solver(const CScript& scriptPubKey, std::vector<std::vector<unsigned vSolutionsRet.push_back({static_cast<unsigned char>(required)}); // safe as required is in range 1..16 vSolutionsRet.insert(vSolutionsRet.end(), keys.begin(), keys.end()); vSolutionsRet.push_back({static_cast<unsigned char>(keys.size())}); // safe as size is in range 1..16 - return TX_MULTISIG; + return TxoutType::MULTISIG; } vSolutionsRet.clear(); - return TX_NONSTANDARD; + return TxoutType::NONSTANDARD; } bool ExtractDestination(const CScript& scriptPubKey, CTxDestination& addressRet) { std::vector<valtype> vSolutions; - txnouttype whichType = Solver(scriptPubKey, vSolutions); + TxoutType whichType = Solver(scriptPubKey, vSolutions); - if (whichType == TX_PUBKEY) { + if (whichType == TxoutType::PUBKEY) { CPubKey pubKey(vSolutions[0]); if (!pubKey.IsValid()) return false; @@ -168,26 +184,26 @@ bool ExtractDestination(const CScript& scriptPubKey, CTxDestination& addressRet) addressRet = PKHash(pubKey); return true; } - else if (whichType == TX_PUBKEYHASH) + else if (whichType == TxoutType::PUBKEYHASH) { addressRet = PKHash(uint160(vSolutions[0])); return true; } - else if (whichType == TX_SCRIPTHASH) + else if (whichType == TxoutType::SCRIPTHASH) { addressRet = ScriptHash(uint160(vSolutions[0])); return true; - } else if (whichType == TX_WITNESS_V0_KEYHASH) { + } else if (whichType == TxoutType::WITNESS_V0_KEYHASH) { WitnessV0KeyHash hash; std::copy(vSolutions[0].begin(), vSolutions[0].end(), hash.begin()); addressRet = hash; return true; - } else if (whichType == TX_WITNESS_V0_SCRIPTHASH) { + } else if (whichType == TxoutType::WITNESS_V0_SCRIPTHASH) { WitnessV0ScriptHash hash; std::copy(vSolutions[0].begin(), vSolutions[0].end(), hash.begin()); addressRet = hash; return true; - } else if (whichType == TX_WITNESS_UNKNOWN) { + } else if (whichType == TxoutType::WITNESS_UNKNOWN) { WitnessUnknown unk; unk.version = vSolutions[0][0]; std::copy(vSolutions[1].begin(), vSolutions[1].end(), unk.program); @@ -199,19 +215,19 @@ bool ExtractDestination(const CScript& scriptPubKey, CTxDestination& addressRet) return false; } -bool ExtractDestinations(const CScript& scriptPubKey, txnouttype& typeRet, std::vector<CTxDestination>& addressRet, int& nRequiredRet) +bool ExtractDestinations(const CScript& scriptPubKey, TxoutType& typeRet, std::vector<CTxDestination>& addressRet, int& nRequiredRet) { addressRet.clear(); std::vector<valtype> vSolutions; typeRet = Solver(scriptPubKey, vSolutions); - if (typeRet == TX_NONSTANDARD) { + if (typeRet == TxoutType::NONSTANDARD) { return false; - } else if (typeRet == TX_NULL_DATA) { + } else if (typeRet == TxoutType::NULL_DATA) { // This is data, not addresses return false; } - if (typeRet == TX_MULTISIG) + if (typeRet == TxoutType::MULTISIG) { nRequiredRet = vSolutions.front()[0]; for (unsigned int i = 1; i < vSolutions.size()-1; i++) @@ -241,59 +257,44 @@ bool ExtractDestinations(const CScript& scriptPubKey, txnouttype& typeRet, std:: namespace { -class CScriptVisitor : public boost::static_visitor<bool> +class CScriptVisitor : public boost::static_visitor<CScript> { -private: - CScript *script; public: - explicit CScriptVisitor(CScript *scriptin) { script = scriptin; } - - bool operator()(const CNoDestination &dest) const { - script->clear(); - return false; + CScript operator()(const CNoDestination& dest) const + { + return CScript(); } - bool operator()(const PKHash &keyID) const { - script->clear(); - *script << OP_DUP << OP_HASH160 << ToByteVector(keyID) << OP_EQUALVERIFY << OP_CHECKSIG; - return true; + CScript operator()(const PKHash& keyID) const + { + return CScript() << OP_DUP << OP_HASH160 << ToByteVector(keyID) << OP_EQUALVERIFY << OP_CHECKSIG; } - bool operator()(const ScriptHash &scriptID) const { - script->clear(); - *script << OP_HASH160 << ToByteVector(scriptID) << OP_EQUAL; - return true; + CScript operator()(const ScriptHash& scriptID) const + { + return CScript() << OP_HASH160 << ToByteVector(scriptID) << OP_EQUAL; } - bool operator()(const WitnessV0KeyHash& id) const + CScript operator()(const WitnessV0KeyHash& id) const { - script->clear(); - *script << OP_0 << ToByteVector(id); - return true; + return CScript() << OP_0 << ToByteVector(id); } - bool operator()(const WitnessV0ScriptHash& id) const + CScript operator()(const WitnessV0ScriptHash& id) const { - script->clear(); - *script << OP_0 << ToByteVector(id); - return true; + return CScript() << OP_0 << ToByteVector(id); } - bool operator()(const WitnessUnknown& id) const + CScript operator()(const WitnessUnknown& id) const { - script->clear(); - *script << CScript::EncodeOP_N(id.version) << std::vector<unsigned char>(id.program, id.program + id.length); - return true; + return CScript() << CScript::EncodeOP_N(id.version) << std::vector<unsigned char>(id.program, id.program + id.length); } }; } // namespace CScript GetScriptForDestination(const CTxDestination& dest) { - CScript script; - - boost::apply_visitor(CScriptVisitor(&script), dest); - return script; + return boost::apply_visitor(CScriptVisitor(), dest); } CScript GetScriptForRawPubKey(const CPubKey& pubKey) @@ -315,11 +316,11 @@ CScript GetScriptForMultisig(int nRequired, const std::vector<CPubKey>& keys) CScript GetScriptForWitness(const CScript& redeemscript) { std::vector<std::vector<unsigned char> > vSolutions; - txnouttype typ = Solver(redeemscript, vSolutions); - if (typ == TX_PUBKEY) { - return GetScriptForDestination(WitnessV0KeyHash(Hash160(vSolutions[0].begin(), vSolutions[0].end()))); - } else if (typ == TX_PUBKEYHASH) { - return GetScriptForDestination(WitnessV0KeyHash(vSolutions[0])); + TxoutType typ = Solver(redeemscript, vSolutions); + if (typ == TxoutType::PUBKEY) { + return GetScriptForDestination(WitnessV0KeyHash(Hash160(vSolutions[0]))); + } else if (typ == TxoutType::PUBKEYHASH) { + return GetScriptForDestination(WitnessV0KeyHash(uint160{vSolutions[0]})); } return GetScriptForDestination(WitnessV0ScriptHash(redeemscript)); } diff --git a/src/script/standard.h b/src/script/standard.h index 2929425670..992e37675f 100644 --- a/src/script/standard.h +++ b/src/script/standard.h @@ -18,14 +18,80 @@ static const bool DEFAULT_ACCEPT_DATACARRIER = true; class CKeyID; class CScript; +struct ScriptHash; + +template<typename HashType> +class BaseHash +{ +protected: + HashType m_hash; + +public: + BaseHash() : m_hash() {} + BaseHash(const HashType& in) : m_hash(in) {} + + unsigned char* begin() + { + return m_hash.begin(); + } + + const unsigned char* begin() const + { + return m_hash.begin(); + } + + unsigned char* end() + { + return m_hash.end(); + } + + const unsigned char* end() const + { + return m_hash.end(); + } + + operator std::vector<unsigned char>() const + { + return std::vector<unsigned char>{m_hash.begin(), m_hash.end()}; + } + + std::string ToString() const + { + return m_hash.ToString(); + } + + bool operator==(const BaseHash<HashType>& other) const noexcept + { + return m_hash == other.m_hash; + } + + bool operator!=(const BaseHash<HashType>& other) const noexcept + { + return !(m_hash == other.m_hash); + } + + bool operator<(const BaseHash<HashType>& other) const noexcept + { + return m_hash < other.m_hash; + } + + size_t size() const + { + return m_hash.size(); + } + + unsigned char* data() { return m_hash.data(); } + const unsigned char* data() const { return m_hash.data(); } +}; /** A reference to a CScript: the Hash160 of its serialization (see script.h) */ -class CScriptID : public uint160 +class CScriptID : public BaseHash<uint160> { public: - CScriptID() : uint160() {} + CScriptID() : BaseHash() {} explicit CScriptID(const CScript& in); - CScriptID(const uint160& in) : uint160(in) {} + explicit CScriptID(const uint160& in) : BaseHash(in) {} + explicit CScriptID(const ScriptHash& in); }; /** @@ -36,11 +102,11 @@ static const unsigned int MAX_OP_RETURN_RELAY = 83; /** * A data carrying output is an unspendable output containing data. The script - * type is designated as TX_NULL_DATA. + * type is designated as TxoutType::NULL_DATA. */ extern bool fAcceptDatacarrier; -/** Maximum size of TX_NULL_DATA scripts that this node considers standard. */ +/** Maximum size of TxoutType::NULL_DATA scripts that this node considers standard. */ extern unsigned nMaxDatacarrierBytes; /** @@ -53,18 +119,17 @@ extern unsigned nMaxDatacarrierBytes; */ static const unsigned int MANDATORY_SCRIPT_VERIFY_FLAGS = SCRIPT_VERIFY_P2SH; -enum txnouttype -{ - TX_NONSTANDARD, +enum class TxoutType { + NONSTANDARD, // 'standard' transaction types: - TX_PUBKEY, - TX_PUBKEYHASH, - TX_SCRIPTHASH, - TX_MULTISIG, - TX_NULL_DATA, //!< unspendable OP_RETURN script that carries data - TX_WITNESS_V0_SCRIPTHASH, - TX_WITNESS_V0_KEYHASH, - TX_WITNESS_UNKNOWN, //!< Only for Witness versions not already defined above + PUBKEY, + PUBKEYHASH, + SCRIPTHASH, + MULTISIG, + NULL_DATA, //!< unspendable OP_RETURN script that carries data + WITNESS_V0_SCRIPTHASH, + WITNESS_V0_KEYHASH, + WITNESS_UNKNOWN, //!< Only for Witness versions not already defined above }; class CNoDestination { @@ -73,41 +138,44 @@ public: friend bool operator<(const CNoDestination &a, const CNoDestination &b) { return true; } }; -struct PKHash : public uint160 +struct PKHash : public BaseHash<uint160> { - PKHash() : uint160() {} - explicit PKHash(const uint160& hash) : uint160(hash) {} + PKHash() : BaseHash() {} + explicit PKHash(const uint160& hash) : BaseHash(hash) {} explicit PKHash(const CPubKey& pubkey); - using uint160::uint160; + explicit PKHash(const CKeyID& pubkey_id); }; +CKeyID ToKeyID(const PKHash& key_hash); struct WitnessV0KeyHash; -struct ScriptHash : public uint160 +struct ScriptHash : public BaseHash<uint160> { - ScriptHash() : uint160() {} + ScriptHash() : BaseHash() {} // These don't do what you'd expect. // Use ScriptHash(GetScriptForDestination(...)) instead. explicit ScriptHash(const WitnessV0KeyHash& hash) = delete; explicit ScriptHash(const PKHash& hash) = delete; - explicit ScriptHash(const uint160& hash) : uint160(hash) {} + + explicit ScriptHash(const uint160& hash) : BaseHash(hash) {} explicit ScriptHash(const CScript& script); - using uint160::uint160; + explicit ScriptHash(const CScriptID& script); }; -struct WitnessV0ScriptHash : public uint256 +struct WitnessV0ScriptHash : public BaseHash<uint256> { - WitnessV0ScriptHash() : uint256() {} - explicit WitnessV0ScriptHash(const uint256& hash) : uint256(hash) {} + WitnessV0ScriptHash() : BaseHash() {} + explicit WitnessV0ScriptHash(const uint256& hash) : BaseHash(hash) {} explicit WitnessV0ScriptHash(const CScript& script); - using uint256::uint256; }; -struct WitnessV0KeyHash : public uint160 +struct WitnessV0KeyHash : public BaseHash<uint160> { - WitnessV0KeyHash() : uint160() {} - explicit WitnessV0KeyHash(const uint160& hash) : uint160(hash) {} - using uint160::uint160; + WitnessV0KeyHash() : BaseHash() {} + explicit WitnessV0KeyHash(const uint160& hash) : BaseHash(hash) {} + explicit WitnessV0KeyHash(const CPubKey& pubkey); + explicit WitnessV0KeyHash(const PKHash& pubkey_hash); }; +CKeyID ToKeyID(const WitnessV0KeyHash& key_hash); //! CTxDestination subtype to encode any future Witness version struct WitnessUnknown @@ -134,11 +202,11 @@ struct WitnessUnknown /** * A txout script template with a specific destination. It is either: * * CNoDestination: no destination set - * * PKHash: TX_PUBKEYHASH destination (P2PKH) - * * ScriptHash: TX_SCRIPTHASH destination (P2SH) - * * WitnessV0ScriptHash: TX_WITNESS_V0_SCRIPTHASH destination (P2WSH) - * * WitnessV0KeyHash: TX_WITNESS_V0_KEYHASH destination (P2WPKH) - * * WitnessUnknown: TX_WITNESS_UNKNOWN destination (P2W???) + * * PKHash: TxoutType::PUBKEYHASH destination (P2PKH) + * * ScriptHash: TxoutType::SCRIPTHASH destination (P2SH) + * * WitnessV0ScriptHash: TxoutType::WITNESS_V0_SCRIPTHASH destination (P2WSH) + * * WitnessV0KeyHash: TxoutType::WITNESS_V0_KEYHASH destination (P2WPKH) + * * WitnessUnknown: TxoutType::WITNESS_UNKNOWN destination (P2W???) * A CTxDestination is the internal data type encoded in a bitcoin address */ typedef boost::variant<CNoDestination, PKHash, ScriptHash, WitnessV0ScriptHash, WitnessV0KeyHash, WitnessUnknown> CTxDestination; @@ -146,8 +214,8 @@ typedef boost::variant<CNoDestination, PKHash, ScriptHash, WitnessV0ScriptHash, /** Check whether a CTxDestination is a CNoDestination. */ bool IsValidDestination(const CTxDestination& dest); -/** Get the name of a txnouttype as a C string, or nullptr if unknown. */ -std::string GetTxnOutputType(txnouttype t); +/** Get the name of a TxoutType as a string */ +std::string GetTxnOutputType(TxoutType t); /** * Parse a scriptPubKey and identify script type for standard scripts. If @@ -157,9 +225,9 @@ std::string GetTxnOutputType(txnouttype t); * * @param[in] scriptPubKey Script to parse * @param[out] vSolutionsRet Vector of parsed pubkeys and hashes - * @return The script type. TX_NONSTANDARD represents a failed solve. + * @return The script type. TxoutType::NONSTANDARD represents a failed solve. */ -txnouttype Solver(const CScript& scriptPubKey, std::vector<std::vector<unsigned char>>& vSolutionsRet); +TxoutType Solver(const CScript& scriptPubKey, std::vector<std::vector<unsigned char>>& vSolutionsRet); /** * Parse a standard scriptPubKey for the destination address. Assigns result to @@ -180,7 +248,7 @@ bool ExtractDestination(const CScript& scriptPubKey, CTxDestination& addressRet) * encodable as an address) with key identifiers (of keys involved in a * CScript), and its use should be phased out. */ -bool ExtractDestinations(const CScript& scriptPubKey, txnouttype& typeRet, std::vector<CTxDestination>& addressRet, int& nRequiredRet); +bool ExtractDestinations(const CScript& scriptPubKey, TxoutType& typeRet, std::vector<CTxDestination>& addressRet, int& nRequiredRet); /** * Generate a Bitcoin scriptPubKey for the given CTxDestination. Returns a P2PKH diff --git a/src/secp256k1/.gitignore b/src/secp256k1/.gitignore index 55d325aeef..cb4331aa90 100644 --- a/src/secp256k1/.gitignore +++ b/src/secp256k1/.gitignore @@ -9,6 +9,7 @@ bench_internal tests exhaustive_tests gen_context +valgrind_ctime_test *.exe *.so *.a diff --git a/src/secp256k1/.travis.yml b/src/secp256k1/.travis.yml index 74f658f4d1..a6ad6fb27e 100644 --- a/src/secp256k1/.travis.yml +++ b/src/secp256k1/.travis.yml @@ -1,18 +1,23 @@ language: c -os: linux +os: + - linux + - osx + +dist: bionic +# Valgrind currently supports upto macOS 10.13, the latest xcode of that version is 10.1 +osx_image: xcode10.1 addons: apt: - packages: libgmp-dev + packages: + - libgmp-dev + - valgrind + - libtool-bin compiler: - clang - gcc -cache: - directories: - - src/java/guava/ env: global: - - FIELD=auto BIGNUM=auto SCALAR=auto ENDOMORPHISM=no STATICPRECOMPUTATION=yes ASM=no BUILD=check EXTRAFLAGS= HOST= ECDH=no RECOVERY=no EXPERIMENTAL=no JNI=no - - GUAVA_URL=https://search.maven.org/remotecontent?filepath=com/google/guava/guava/18.0/guava-18.0.jar GUAVA_JAR=src/java/guava/guava-18.0.jar + - FIELD=auto BIGNUM=auto SCALAR=auto ENDOMORPHISM=no STATICPRECOMPUTATION=yes ECMULTGENPRECISION=auto ASM=no BUILD=check EXTRAFLAGS= HOST= ECDH=no RECOVERY=no EXPERIMENTAL=no CTIMETEST=yes BENCH=yes ITERS=2 matrix: - SCALAR=32bit RECOVERY=yes - SCALAR=32bit FIELD=32bit ECDH=yes EXPERIMENTAL=yes @@ -26,43 +31,78 @@ env: - BIGNUM=no - BIGNUM=no ENDOMORPHISM=yes RECOVERY=yes EXPERIMENTAL=yes - BIGNUM=no STATICPRECOMPUTATION=no - - BUILD=distcheck - - EXTRAFLAGS=CPPFLAGS=-DDETERMINISTIC - - EXTRAFLAGS=CFLAGS=-O0 - - BUILD=check-java JNI=yes ECDH=yes EXPERIMENTAL=yes + - BUILD=distcheck CTIMETEST= BENCH= + - CPPFLAGS=-DDETERMINISTIC + - CFLAGS=-O0 CTIMETEST= + - ECMULTGENPRECISION=2 + - ECMULTGENPRECISION=8 + - VALGRIND=yes ENDOMORPHISM=yes BIGNUM=no ASM=x86_64 EXPERIMENTAL=yes ECDH=yes RECOVERY=yes EXTRAFLAGS="--disable-openssl-tests" CPPFLAGS=-DVALGRIND BUILD= + - VALGRIND=yes BIGNUM=no ASM=x86_64 EXPERIMENTAL=yes ECDH=yes RECOVERY=yes EXTRAFLAGS="--disable-openssl-tests" CPPFLAGS=-DVALGRIND BUILD= matrix: fast_finish: true include: - compiler: clang + os: linux env: HOST=i686-linux-gnu ENDOMORPHISM=yes addons: apt: packages: - gcc-multilib - libgmp-dev:i386 + - valgrind + - libtool-bin + - libc6-dbg:i386 - compiler: clang env: HOST=i686-linux-gnu + os: linux addons: apt: packages: - gcc-multilib + - valgrind + - libtool-bin + - libc6-dbg:i386 - compiler: gcc env: HOST=i686-linux-gnu ENDOMORPHISM=yes + os: linux addons: apt: packages: - gcc-multilib + - valgrind + - libtool-bin + - libc6-dbg:i386 - compiler: gcc + os: linux env: HOST=i686-linux-gnu addons: apt: packages: - gcc-multilib - libgmp-dev:i386 -before_install: mkdir -p `dirname $GUAVA_JAR` -install: if [ ! -f $GUAVA_JAR ]; then wget $GUAVA_URL -O $GUAVA_JAR; fi + - valgrind + - libtool-bin + - libc6-dbg:i386 + +# We use this to install macOS dependencies instead of the built in `homebrew` plugin, +# because in xcode earlier than 11 they have a bug requiring updating the system which overall takes ~8 minutes. +# https://travis-ci.community/t/macos-build-fails-because-of-homebrew-bundle-unknown-command/7296 +before_install: + - if [ "${TRAVIS_OS_NAME}" = "osx" ]; then HOMEBREW_NO_AUTO_UPDATE=1 brew install gmp valgrind gcc@9; fi + before_script: ./autogen.sh + +# travis auto terminates jobs that go for 10 minutes without printing to stdout, but travis_wait doesn't work well with forking programs like valgrind (https://docs.travis-ci.com/user/common-build-problems/#build-times-out-because-no-output-was-received https://github.com/bitcoin-core/secp256k1/pull/750#issuecomment-623476860) script: - - if [ -n "$HOST" ]; then export USE_HOST="--host=$HOST"; fi - - if [ "x$HOST" = "xi686-linux-gnu" ]; then export CC="$CC -m32"; fi - - ./configure --enable-experimental=$EXPERIMENTAL --enable-endomorphism=$ENDOMORPHISM --with-field=$FIELD --with-bignum=$BIGNUM --with-scalar=$SCALAR --enable-ecmult-static-precomputation=$STATICPRECOMPUTATION --enable-module-ecdh=$ECDH --enable-module-recovery=$RECOVERY --enable-jni=$JNI $EXTRAFLAGS $USE_HOST && make -j2 $BUILD + - function keep_alive() { while true; do echo -en "\a"; sleep 60; done } + - keep_alive & + - ./contrib/travis.sh + - kill %keep_alive + +after_script: + - cat ./tests.log + - cat ./exhaustive_tests.log + - cat ./valgrind_ctime_test.log + - cat ./bench.log + - $CC --version + - valgrind --version diff --git a/src/secp256k1/Makefile.am b/src/secp256k1/Makefile.am index 9e5b7dcce0..d8c1c79e8c 100644 --- a/src/secp256k1/Makefile.am +++ b/src/secp256k1/Makefile.am @@ -1,13 +1,8 @@ ACLOCAL_AMFLAGS = -I build-aux/m4 lib_LTLIBRARIES = libsecp256k1.la -if USE_JNI -JNI_LIB = libsecp256k1_jni.la -noinst_LTLIBRARIES = $(JNI_LIB) -else -JNI_LIB = -endif include_HEADERS = include/secp256k1.h +include_HEADERS += include/secp256k1_preallocated.h noinst_HEADERS = noinst_HEADERS += src/scalar.h noinst_HEADERS += src/scalar_4x64.h @@ -39,8 +34,6 @@ noinst_HEADERS += src/field_5x52.h noinst_HEADERS += src/field_5x52_impl.h noinst_HEADERS += src/field_5x52_int128_impl.h noinst_HEADERS += src/field_5x52_asm_impl.h -noinst_HEADERS += src/java/org_bitcoin_NativeSecp256k1.h -noinst_HEADERS += src/java/org_bitcoin_Secp256k1Context.h noinst_HEADERS += src/util.h noinst_HEADERS += src/scratch.h noinst_HEADERS += src/scratch_impl.h @@ -74,16 +67,19 @@ endif libsecp256k1_la_SOURCES = src/secp256k1.c libsecp256k1_la_CPPFLAGS = -DSECP256K1_BUILD -I$(top_srcdir)/include -I$(top_srcdir)/src $(SECP_INCLUDES) -libsecp256k1_la_LIBADD = $(JNI_LIB) $(SECP_LIBS) $(COMMON_LIB) +libsecp256k1_la_LIBADD = $(SECP_LIBS) $(COMMON_LIB) -libsecp256k1_jni_la_SOURCES = src/java/org_bitcoin_NativeSecp256k1.c src/java/org_bitcoin_Secp256k1Context.c -libsecp256k1_jni_la_CPPFLAGS = -DSECP256K1_BUILD $(JNI_INCLUDES) +if VALGRIND_ENABLED +libsecp256k1_la_CPPFLAGS += -DVALGRIND +endif noinst_PROGRAMS = if USE_BENCHMARK noinst_PROGRAMS += bench_verify bench_sign bench_internal bench_ecmult bench_verify_SOURCES = src/bench_verify.c bench_verify_LDADD = libsecp256k1.la $(SECP_LIBS) $(SECP_TEST_LIBS) $(COMMON_LIB) +# SECP_TEST_INCLUDES are only used here for CRYPTO_CPPFLAGS +bench_verify_CPPFLAGS = -DSECP256K1_BUILD $(SECP_TEST_INCLUDES) bench_sign_SOURCES = src/bench_sign.c bench_sign_LDADD = libsecp256k1.la $(SECP_LIBS) $(SECP_TEST_LIBS) $(COMMON_LIB) bench_internal_SOURCES = src/bench_internal.c @@ -99,6 +95,12 @@ if USE_TESTS noinst_PROGRAMS += tests tests_SOURCES = src/tests.c tests_CPPFLAGS = -DSECP256K1_BUILD -I$(top_srcdir)/src -I$(top_srcdir)/include $(SECP_INCLUDES) $(SECP_TEST_INCLUDES) +if VALGRIND_ENABLED +tests_CPPFLAGS += -DVALGRIND +noinst_PROGRAMS += valgrind_ctime_test +valgrind_ctime_test_SOURCES = src/valgrind_ctime_test.c +valgrind_ctime_test_LDADD = libsecp256k1.la $(SECP_LIBS) $(SECP_TEST_LIBS) $(COMMON_LIB) +endif if !ENABLE_COVERAGE tests_CPPFLAGS += -DVERIFY endif @@ -119,42 +121,12 @@ exhaustive_tests_LDFLAGS = -static TESTS += exhaustive_tests endif -JAVAROOT=src/java -JAVAORG=org/bitcoin -JAVA_GUAVA=$(srcdir)/$(JAVAROOT)/guava/guava-18.0.jar -CLASSPATH_ENV=CLASSPATH=$(JAVA_GUAVA) -JAVA_FILES= \ - $(JAVAROOT)/$(JAVAORG)/NativeSecp256k1.java \ - $(JAVAROOT)/$(JAVAORG)/NativeSecp256k1Test.java \ - $(JAVAROOT)/$(JAVAORG)/NativeSecp256k1Util.java \ - $(JAVAROOT)/$(JAVAORG)/Secp256k1Context.java - -if USE_JNI - -$(JAVA_GUAVA): - @echo Guava is missing. Fetch it via: \ - wget https://search.maven.org/remotecontent?filepath=com/google/guava/guava/18.0/guava-18.0.jar -O $(@) - @false - -.stamp-java: $(JAVA_FILES) - @echo Compiling $^ - $(AM_V_at)$(CLASSPATH_ENV) javac $^ - @touch $@ - -if USE_TESTS - -check-java: libsecp256k1.la $(JAVA_GUAVA) .stamp-java - $(AM_V_at)java -Djava.library.path="./:./src:./src/.libs:.libs/" -cp "$(JAVA_GUAVA):$(JAVAROOT)" $(JAVAORG)/NativeSecp256k1Test - -endif -endif - if USE_ECMULT_STATIC_PRECOMPUTATION -CPPFLAGS_FOR_BUILD +=-I$(top_srcdir) +CPPFLAGS_FOR_BUILD +=-I$(top_srcdir) -I$(builddir)/src gen_context_OBJECTS = gen_context.o gen_context_BIN = gen_context$(BUILD_EXEEXT) -gen_%.o: src/gen_%.c +gen_%.o: src/gen_%.c src/libsecp256k1-config.h $(CC_FOR_BUILD) $(CPPFLAGS_FOR_BUILD) $(CFLAGS_FOR_BUILD) -c $< -o $@ $(gen_context_BIN): $(gen_context_OBJECTS) @@ -168,10 +140,10 @@ $(bench_ecmult_OBJECTS): src/ecmult_static_context.h src/ecmult_static_context.h: $(gen_context_BIN) ./$(gen_context_BIN) -CLEANFILES = $(gen_context_BIN) src/ecmult_static_context.h $(JAVAROOT)/$(JAVAORG)/*.class .stamp-java +CLEANFILES = $(gen_context_BIN) src/ecmult_static_context.h endif -EXTRA_DIST = autogen.sh src/gen_context.c src/basic-config.h $(JAVA_FILES) +EXTRA_DIST = autogen.sh src/gen_context.c src/basic-config.h if ENABLE_MODULE_ECDH include src/modules/ecdh/Makefile.am.include diff --git a/src/secp256k1/README.md b/src/secp256k1/README.md index 8cd344ea81..434178b372 100644 --- a/src/secp256k1/README.md +++ b/src/secp256k1/README.md @@ -3,17 +3,22 @@ libsecp256k1 [![Build Status](https://travis-ci.org/bitcoin-core/secp256k1.svg?branch=master)](https://travis-ci.org/bitcoin-core/secp256k1) -Optimized C library for EC operations on curve secp256k1. +Optimized C library for ECDSA signatures and secret/public key operations on curve secp256k1. -This library is a work in progress and is being used to research best practices. Use at your own risk. +This library is intended to be the highest quality publicly available library for cryptography on the secp256k1 curve. However, the primary focus of its development has been for usage in the Bitcoin system and usage unlike Bitcoin's may be less well tested, verified, or suffer from a less well thought out interface. Correct usage requires some care and consideration that the library is fit for your application's purpose. Features: * secp256k1 ECDSA signing/verification and key generation. -* Adding/multiplying private/public keys. -* Serialization/parsing of private keys, public keys, signatures. -* Constant time, constant memory access signing and pubkey generation. -* Derandomized DSA (via RFC6979 or with a caller provided function.) +* Additive and multiplicative tweaking of secret/public keys. +* Serialization/parsing of secret keys, public keys, signatures. +* Constant time, constant memory access signing and public key generation. +* Derandomized ECDSA (via RFC6979 or with a caller provided function.) * Very efficient implementation. +* Suitable for embedded systems. +* Optional module for public key recovery. +* Optional module for ECDH key exchange (experimental). + +Experimental features have not received enough scrutiny to satisfy the standard of quality of this library but are made available for testing and review by the community. The APIs of these features should not be considered stable. Implementation details ---------------------- @@ -23,11 +28,12 @@ Implementation details * Extensive testing infrastructure. * Structured to facilitate review and analysis. * Intended to be portable to any system with a C89 compiler and uint64_t support. + * No use of floating types. * Expose only higher level interfaces to minimize the API surface and improve application security. ("Be difficult to use insecurely.") * Field operations * Optimized implementation of arithmetic modulo the curve's field size (2^256 - 0x1000003D1). * Using 5 52-bit limbs (including hand-optimized assembly for x86_64, by Diederik Huys). - * Using 10 26-bit limbs. + * Using 10 26-bit limbs (including hand-optimized assembly for 32-bit ARM, by Wladimir J. van der Laan). * Field inverses and square roots using a sliding window over blocks of 1s (by Peter Dettman). * Scalar operations * Optimized implementation without data-dependent branches of arithmetic modulo the curve's order. @@ -45,9 +51,11 @@ Implementation details * Optionally (off by default) use secp256k1's efficiently-computable endomorphism to split the P multiplicand into 2 half-sized ones. * Point multiplication for signing * Use a precomputed table of multiples of powers of 16 multiplied with the generator, so general multiplication becomes a series of additions. - * Access the table with branch-free conditional moves so memory access is uniform. - * No data-dependent branches - * The precomputed tables add and eventually subtract points for which no known scalar (private key) is known, preventing even an attacker with control over the private key used to control the data internally. + * Intended to be completely free of timing sidechannels for secret-key operations (on reasonable hardware/toolchains) + * Access the table with branch-free conditional moves so memory access is uniform. + * No data-dependent branches + * Optional runtime blinding which attempts to frustrate differential power analysis. + * The precomputed tables add and eventually subtract points for which no known scalar (secret key) is known, preventing even an attacker with control over the secret key used to control the data internally. Build steps ----------- @@ -57,5 +65,40 @@ libsecp256k1 is built using autotools: $ ./autogen.sh $ ./configure $ make - $ ./tests + $ make check $ sudo make install # optional + +Exhaustive tests +----------- + + $ ./exhaustive_tests + +With valgrind, you might need to increase the max stack size: + + $ valgrind --max-stackframe=2500000 ./exhaustive_tests + +Test coverage +----------- + +This library aims to have full coverage of the reachable lines and branches. + +To create a test coverage report, configure with `--enable-coverage` (use of GCC is necessary): + + $ ./configure --enable-coverage + +Run the tests: + + $ make check + +To create a report, `gcovr` is recommended, as it includes branch coverage reporting: + + $ gcovr --exclude 'src/bench*' --print-summary + +To create a HTML report with coloured and annotated source code: + + $ gcovr --exclude 'src/bench*' --html --html-details -o coverage.html + +Reporting a vulnerability +------------ + +See [SECURITY.md](SECURITY.md) diff --git a/src/secp256k1/SECURITY.md b/src/secp256k1/SECURITY.md new file mode 100644 index 0000000000..0e4d588030 --- /dev/null +++ b/src/secp256k1/SECURITY.md @@ -0,0 +1,15 @@ +# Security Policy + +## Reporting a Vulnerability + +To report security issues send an email to secp256k1-security@bitcoincore.org (not for support). + +The following keys may be used to communicate sensitive information to developers: + +| Name | Fingerprint | +|------|-------------| +| Pieter Wuille | 133E AC17 9436 F14A 5CF1 B794 860F EB80 4E66 9320 | +| Andrew Poelstra | 699A 63EF C17A D3A9 A34C FFC0 7AD0 A91C 40BD 0091 | +| Tim Ruffing | 09E0 3F87 1092 E40E 106E 902B 33BC 86AB 80FF 5516 | + +You can import a key by running the following command with that individual’s fingerprint: `gpg --recv-keys "<fingerprint>"` Ensure that you put quotes around fingerprints containing spaces. diff --git a/src/secp256k1/build-aux/m4/ax_jni_include_dir.m4 b/src/secp256k1/build-aux/m4/ax_jni_include_dir.m4 deleted file mode 100644 index cdc78d87d4..0000000000 --- a/src/secp256k1/build-aux/m4/ax_jni_include_dir.m4 +++ /dev/null @@ -1,145 +0,0 @@ -# =========================================================================== -# https://www.gnu.org/software/autoconf-archive/ax_jni_include_dir.html -# =========================================================================== -# -# SYNOPSIS -# -# AX_JNI_INCLUDE_DIR -# -# DESCRIPTION -# -# AX_JNI_INCLUDE_DIR finds include directories needed for compiling -# programs using the JNI interface. -# -# JNI include directories are usually in the Java distribution. This is -# deduced from the value of $JAVA_HOME, $JAVAC, or the path to "javac", in -# that order. When this macro completes, a list of directories is left in -# the variable JNI_INCLUDE_DIRS. -# -# Example usage follows: -# -# AX_JNI_INCLUDE_DIR -# -# for JNI_INCLUDE_DIR in $JNI_INCLUDE_DIRS -# do -# CPPFLAGS="$CPPFLAGS -I$JNI_INCLUDE_DIR" -# done -# -# If you want to force a specific compiler: -# -# - at the configure.in level, set JAVAC=yourcompiler before calling -# AX_JNI_INCLUDE_DIR -# -# - at the configure level, setenv JAVAC -# -# Note: This macro can work with the autoconf M4 macros for Java programs. -# This particular macro is not part of the original set of macros. -# -# LICENSE -# -# Copyright (c) 2008 Don Anderson <dda@sleepycat.com> -# -# Copying and distribution of this file, with or without modification, are -# permitted in any medium without royalty provided the copyright notice -# and this notice are preserved. This file is offered as-is, without any -# warranty. - -#serial 14 - -AU_ALIAS([AC_JNI_INCLUDE_DIR], [AX_JNI_INCLUDE_DIR]) -AC_DEFUN([AX_JNI_INCLUDE_DIR],[ - -JNI_INCLUDE_DIRS="" - -if test "x$JAVA_HOME" != x; then - _JTOPDIR="$JAVA_HOME" -else - if test "x$JAVAC" = x; then - JAVAC=javac - fi - AC_PATH_PROG([_ACJNI_JAVAC], [$JAVAC], [no]) - if test "x$_ACJNI_JAVAC" = xno; then - AC_MSG_WARN([cannot find JDK; try setting \$JAVAC or \$JAVA_HOME]) - fi - _ACJNI_FOLLOW_SYMLINKS("$_ACJNI_JAVAC") - _JTOPDIR=`echo "$_ACJNI_FOLLOWED" | sed -e 's://*:/:g' -e 's:/[[^/]]*$::'` -fi - -case "$host_os" in - darwin*) # Apple Java headers are inside the Xcode bundle. - macos_version=$(sw_vers -productVersion | sed -n -e 's/^@<:@0-9@:>@*.\(@<:@0-9@:>@*\).@<:@0-9@:>@*/\1/p') - if @<:@ "$macos_version" -gt "7" @:>@; then - _JTOPDIR="$(xcrun --show-sdk-path)/System/Library/Frameworks/JavaVM.framework" - _JINC="$_JTOPDIR/Headers" - else - _JTOPDIR="/System/Library/Frameworks/JavaVM.framework" - _JINC="$_JTOPDIR/Headers" - fi - ;; - *) _JINC="$_JTOPDIR/include";; -esac -_AS_ECHO_LOG([_JTOPDIR=$_JTOPDIR]) -_AS_ECHO_LOG([_JINC=$_JINC]) - -# On Mac OS X 10.6.4, jni.h is a symlink: -# /System/Library/Frameworks/JavaVM.framework/Versions/Current/Headers/jni.h -# -> ../../CurrentJDK/Headers/jni.h. -AC_CACHE_CHECK(jni headers, ac_cv_jni_header_path, -[ - if test -f "$_JINC/jni.h"; then - ac_cv_jni_header_path="$_JINC" - JNI_INCLUDE_DIRS="$JNI_INCLUDE_DIRS $ac_cv_jni_header_path" - else - _JTOPDIR=`echo "$_JTOPDIR" | sed -e 's:/[[^/]]*$::'` - if test -f "$_JTOPDIR/include/jni.h"; then - ac_cv_jni_header_path="$_JTOPDIR/include" - JNI_INCLUDE_DIRS="$JNI_INCLUDE_DIRS $ac_cv_jni_header_path" - else - ac_cv_jni_header_path=none - fi - fi -]) - -# get the likely subdirectories for system specific java includes -case "$host_os" in -bsdi*) _JNI_INC_SUBDIRS="bsdos";; -freebsd*) _JNI_INC_SUBDIRS="freebsd";; -darwin*) _JNI_INC_SUBDIRS="darwin";; -linux*) _JNI_INC_SUBDIRS="linux genunix";; -osf*) _JNI_INC_SUBDIRS="alpha";; -solaris*) _JNI_INC_SUBDIRS="solaris";; -mingw*) _JNI_INC_SUBDIRS="win32";; -cygwin*) _JNI_INC_SUBDIRS="win32";; -*) _JNI_INC_SUBDIRS="genunix";; -esac - -if test "x$ac_cv_jni_header_path" != "xnone"; then - # add any subdirectories that are present - for JINCSUBDIR in $_JNI_INC_SUBDIRS - do - if test -d "$_JTOPDIR/include/$JINCSUBDIR"; then - JNI_INCLUDE_DIRS="$JNI_INCLUDE_DIRS $_JTOPDIR/include/$JINCSUBDIR" - fi - done -fi -]) - -# _ACJNI_FOLLOW_SYMLINKS <path> -# Follows symbolic links on <path>, -# finally setting variable _ACJNI_FOLLOWED -# ---------------------------------------- -AC_DEFUN([_ACJNI_FOLLOW_SYMLINKS],[ -# find the include directory relative to the javac executable -_cur="$1" -while ls -ld "$_cur" 2>/dev/null | grep " -> " >/dev/null; do - AC_MSG_CHECKING([symlink for $_cur]) - _slink=`ls -ld "$_cur" | sed 's/.* -> //'` - case "$_slink" in - /*) _cur="$_slink";; - # 'X' avoids triggering unwanted echo options. - *) _cur=`echo "X$_cur" | sed -e 's/^X//' -e 's:[[^/]]*$::'`"$_slink";; - esac - AC_MSG_RESULT([$_cur]) -done -_ACJNI_FOLLOWED="$_cur" -])# _ACJNI diff --git a/src/secp256k1/build-aux/m4/bitcoin_secp.m4 b/src/secp256k1/build-aux/m4/bitcoin_secp.m4 index 3b3975cbdd..1b2b71e6ab 100644 --- a/src/secp256k1/build-aux/m4/bitcoin_secp.m4 +++ b/src/secp256k1/build-aux/m4/bitcoin_secp.m4 @@ -38,6 +38,8 @@ AC_DEFUN([SECP_OPENSSL_CHECK],[ fi if test x"$has_libcrypto" = x"yes" && test x"$has_openssl_ec" = x; then AC_MSG_CHECKING(for EC functions in libcrypto) + CPPFLAGS_TEMP="$CPPFLAGS" + CPPFLAGS="$CRYPTO_CPPFLAGS $CPPFLAGS" AC_COMPILE_IFELSE([AC_LANG_PROGRAM([[ #include <openssl/ec.h> #include <openssl/ecdsa.h> @@ -51,6 +53,7 @@ if test x"$has_libcrypto" = x"yes" && test x"$has_openssl_ec" = x; then ECDSA_SIG_free(sig_openssl); ]])],[has_openssl_ec=yes],[has_openssl_ec=no]) AC_MSG_RESULT([$has_openssl_ec]) + CPPFLAGS="$CPPFLAGS_TEMP" fi ]) diff --git a/src/secp256k1/configure.ac b/src/secp256k1/configure.ac index 3b7a328c8a..6021b760b5 100644 --- a/src/secp256k1/configure.ac +++ b/src/secp256k1/configure.ac @@ -7,6 +7,11 @@ AH_TOP([#ifndef LIBSECP256K1_CONFIG_H]) AH_TOP([#define LIBSECP256K1_CONFIG_H]) AH_BOTTOM([#endif /*LIBSECP256K1_CONFIG_H*/]) AM_INIT_AUTOMAKE([foreign subdir-objects]) + +# Set -g if CFLAGS are not already set, which matches the default autoconf +# behavior (see PROG_CC in the Autoconf manual) with the exception that we don't +# set -O2 here because we set it in any case (see further down). +: ${CFLAGS="-g"} LT_INIT dnl make the compilation flags quiet unless V=1 is used @@ -19,10 +24,6 @@ AC_PATH_TOOL(RANLIB, ranlib) AC_PATH_TOOL(STRIP, strip) AX_PROG_CC_FOR_BUILD -if test "x$CFLAGS" = "x"; then - CFLAGS="-g" -fi - AM_PROG_CC_C_O AC_PROG_CC_C89 @@ -45,6 +46,7 @@ case $host_os in if test x$openssl_prefix != x; then PKG_CONFIG_PATH="$openssl_prefix/lib/pkgconfig:$PKG_CONFIG_PATH" export PKG_CONFIG_PATH + CRYPTO_CPPFLAGS="-I$openssl_prefix/include" fi if test x$gmp_prefix != x; then GMP_CPPFLAGS="-I$gmp_prefix/include" @@ -63,11 +65,11 @@ case $host_os in ;; esac -CFLAGS="$CFLAGS -W" +CFLAGS="-W $CFLAGS" warn_CFLAGS="-std=c89 -pedantic -Wall -Wextra -Wcast-align -Wnested-externs -Wshadow -Wstrict-prototypes -Wno-unused-function -Wno-long-long -Wno-overlength-strings" saved_CFLAGS="$CFLAGS" -CFLAGS="$CFLAGS $warn_CFLAGS" +CFLAGS="$warn_CFLAGS $CFLAGS" AC_MSG_CHECKING([if ${CC} supports ${warn_CFLAGS}]) AC_COMPILE_IFELSE([AC_LANG_SOURCE([[char foo;]])], [ AC_MSG_RESULT([yes]) ], @@ -76,7 +78,7 @@ AC_COMPILE_IFELSE([AC_LANG_SOURCE([[char foo;]])], ]) saved_CFLAGS="$CFLAGS" -CFLAGS="$CFLAGS -fvisibility=hidden" +CFLAGS="-fvisibility=hidden $CFLAGS" AC_MSG_CHECKING([if ${CC} supports -fvisibility=hidden]) AC_COMPILE_IFELSE([AC_LANG_SOURCE([[char foo;]])], [ AC_MSG_RESULT([yes]) ], @@ -85,42 +87,42 @@ AC_COMPILE_IFELSE([AC_LANG_SOURCE([[char foo;]])], ]) AC_ARG_ENABLE(benchmark, - AS_HELP_STRING([--enable-benchmark],[compile benchmark (default is yes)]), + AS_HELP_STRING([--enable-benchmark],[compile benchmark [default=yes]]), [use_benchmark=$enableval], [use_benchmark=yes]) AC_ARG_ENABLE(coverage, - AS_HELP_STRING([--enable-coverage],[enable compiler flags to support kcov coverage analysis]), + AS_HELP_STRING([--enable-coverage],[enable compiler flags to support kcov coverage analysis [default=no]]), [enable_coverage=$enableval], [enable_coverage=no]) AC_ARG_ENABLE(tests, - AS_HELP_STRING([--enable-tests],[compile tests (default is yes)]), + AS_HELP_STRING([--enable-tests],[compile tests [default=yes]]), [use_tests=$enableval], [use_tests=yes]) AC_ARG_ENABLE(openssl_tests, - AS_HELP_STRING([--enable-openssl-tests],[enable OpenSSL tests, if OpenSSL is available (default is auto)]), + AS_HELP_STRING([--enable-openssl-tests],[enable OpenSSL tests [default=auto]]), [enable_openssl_tests=$enableval], [enable_openssl_tests=auto]) AC_ARG_ENABLE(experimental, - AS_HELP_STRING([--enable-experimental],[allow experimental configure options (default is no)]), + AS_HELP_STRING([--enable-experimental],[allow experimental configure options [default=no]]), [use_experimental=$enableval], [use_experimental=no]) AC_ARG_ENABLE(exhaustive_tests, - AS_HELP_STRING([--enable-exhaustive-tests],[compile exhaustive tests (default is yes)]), + AS_HELP_STRING([--enable-exhaustive-tests],[compile exhaustive tests [default=yes]]), [use_exhaustive_tests=$enableval], [use_exhaustive_tests=yes]) AC_ARG_ENABLE(endomorphism, - AS_HELP_STRING([--enable-endomorphism],[enable endomorphism (default is no)]), + AS_HELP_STRING([--enable-endomorphism],[enable endomorphism [default=no]]), [use_endomorphism=$enableval], [use_endomorphism=no]) AC_ARG_ENABLE(ecmult_static_precomputation, - AS_HELP_STRING([--enable-ecmult-static-precomputation],[enable precomputed ecmult table for signing (default is yes)]), + AS_HELP_STRING([--enable-ecmult-static-precomputation],[enable precomputed ecmult table for signing [default=auto]]), [use_ecmult_static_precomputation=$enableval], [use_ecmult_static_precomputation=auto]) @@ -130,35 +132,55 @@ AC_ARG_ENABLE(module_ecdh, [enable_module_ecdh=no]) AC_ARG_ENABLE(module_recovery, - AS_HELP_STRING([--enable-module-recovery],[enable ECDSA pubkey recovery module (default is no)]), + AS_HELP_STRING([--enable-module-recovery],[enable ECDSA pubkey recovery module [default=no]]), [enable_module_recovery=$enableval], [enable_module_recovery=no]) -AC_ARG_ENABLE(jni, - AS_HELP_STRING([--enable-jni],[enable libsecp256k1_jni (default is no)]), - [use_jni=$enableval], - [use_jni=no]) +AC_ARG_ENABLE(external_default_callbacks, + AS_HELP_STRING([--enable-external-default-callbacks],[enable external default callback functions [default=no]]), + [use_external_default_callbacks=$enableval], + [use_external_default_callbacks=no]) AC_ARG_WITH([field], [AS_HELP_STRING([--with-field=64bit|32bit|auto], -[Specify Field Implementation. Default is auto])],[req_field=$withval], [req_field=auto]) +[finite field implementation to use [default=auto]])],[req_field=$withval], [req_field=auto]) AC_ARG_WITH([bignum], [AS_HELP_STRING([--with-bignum=gmp|no|auto], -[Specify Bignum Implementation. Default is auto])],[req_bignum=$withval], [req_bignum=auto]) +[bignum implementation to use [default=auto]])],[req_bignum=$withval], [req_bignum=auto]) AC_ARG_WITH([scalar], [AS_HELP_STRING([--with-scalar=64bit|32bit|auto], -[Specify scalar implementation. Default is auto])],[req_scalar=$withval], [req_scalar=auto]) - -AC_ARG_WITH([asm], [AS_HELP_STRING([--with-asm=x86_64|arm|no|auto] -[Specify assembly optimizations to use. Default is auto (experimental: arm)])],[req_asm=$withval], [req_asm=auto]) +[scalar implementation to use [default=auto]])],[req_scalar=$withval], [req_scalar=auto]) + +AC_ARG_WITH([asm], [AS_HELP_STRING([--with-asm=x86_64|arm|no|auto], +[assembly optimizations to use (experimental: arm) [default=auto]])],[req_asm=$withval], [req_asm=auto]) + +AC_ARG_WITH([ecmult-window], [AS_HELP_STRING([--with-ecmult-window=SIZE|auto], +[window size for ecmult precomputation for verification, specified as integer in range [2..24].] +[Larger values result in possibly better performance at the cost of an exponentially larger precomputed table.] +[The table will store 2^(SIZE-2) * 64 bytes of data but can be larger in memory due to platform-specific padding and alignment.] +[If the endomorphism optimization is enabled, two tables of this size are used instead of only one.] +["auto" is a reasonable setting for desktop machines (currently 15). [default=auto]] +)], +[req_ecmult_window=$withval], [req_ecmult_window=auto]) + +AC_ARG_WITH([ecmult-gen-precision], [AS_HELP_STRING([--with-ecmult-gen-precision=2|4|8|auto], +[Precision bits to tune the precomputed table size for signing.] +[The size of the table is 32kB for 2 bits, 64kB for 4 bits, 512kB for 8 bits of precision.] +[A larger table size usually results in possible faster signing.] +["auto" is a reasonable setting for desktop machines (currently 4). [default=auto]] +)], +[req_ecmult_gen_precision=$withval], [req_ecmult_gen_precision=auto]) AC_CHECK_TYPES([__int128]) +AC_CHECK_HEADER([valgrind/memcheck.h], [enable_valgrind=yes], [enable_valgrind=no], []) +AM_CONDITIONAL([VALGRIND_ENABLED],[test "$enable_valgrind" = "yes"]) + if test x"$enable_coverage" = x"yes"; then AC_DEFINE(COVERAGE, 1, [Define this symbol to compile out all VERIFY code]) - CFLAGS="$CFLAGS -O0 --coverage" - LDFLAGS="--coverage" + CFLAGS="-O0 --coverage $CFLAGS" + LDFLAGS="--coverage $LDFLAGS" else - CFLAGS="$CFLAGS -O3" + CFLAGS="-O2 $CFLAGS" fi if test x"$use_ecmult_static_precomputation" != x"no"; then @@ -176,7 +198,7 @@ if test x"$use_ecmult_static_precomputation" != x"no"; then warn_CFLAGS_FOR_BUILD="-Wall -Wextra -Wno-unused-function" saved_CFLAGS="$CFLAGS" - CFLAGS="$CFLAGS $warn_CFLAGS_FOR_BUILD" + CFLAGS="$warn_CFLAGS_FOR_BUILD $CFLAGS" AC_MSG_CHECKING([if native ${CC_FOR_BUILD} supports ${warn_CFLAGS_FOR_BUILD}]) AC_COMPILE_IFELSE([AC_LANG_SOURCE([[char foo;]])], [ AC_MSG_RESULT([yes]) ], @@ -188,7 +210,7 @@ if test x"$use_ecmult_static_precomputation" != x"no"; then AC_RUN_IFELSE( [AC_LANG_PROGRAM([], [])], [working_native_cc=yes], - [working_native_cc=no],[dnl]) + [working_native_cc=no],[:]) CFLAGS_FOR_BUILD="$CFLAGS" @@ -387,12 +409,50 @@ case $set_scalar in ;; esac +#set ecmult window size +if test x"$req_ecmult_window" = x"auto"; then + set_ecmult_window=15 +else + set_ecmult_window=$req_ecmult_window +fi + +error_window_size=['window size for ecmult precomputation not an integer in range [2..24] or "auto"'] +case $set_ecmult_window in +''|*[[!0-9]]*) + # no valid integer + AC_MSG_ERROR($error_window_size) + ;; +*) + if test "$set_ecmult_window" -lt 2 -o "$set_ecmult_window" -gt 24 ; then + # not in range + AC_MSG_ERROR($error_window_size) + fi + AC_DEFINE_UNQUOTED(ECMULT_WINDOW_SIZE, $set_ecmult_window, [Set window size for ecmult precomputation]) + ;; +esac + +#set ecmult gen precision +if test x"$req_ecmult_gen_precision" = x"auto"; then + set_ecmult_gen_precision=4 +else + set_ecmult_gen_precision=$req_ecmult_gen_precision +fi + +case $set_ecmult_gen_precision in +2|4|8) + AC_DEFINE_UNQUOTED(ECMULT_GEN_PREC_BITS, $set_ecmult_gen_precision, [Set ecmult gen precision bits]) + ;; +*) + AC_MSG_ERROR(['ecmult gen precision not 2, 4, 8 or "auto"']) + ;; +esac + if test x"$use_tests" = x"yes"; then SECP_OPENSSL_CHECK if test x"$has_openssl_ec" = x"yes"; then if test x"$enable_openssl_tests" != x"no"; then AC_DEFINE(ENABLE_OPENSSL_TESTS, 1, [Define this symbol if OpenSSL EC functions are available]) - SECP_TEST_INCLUDES="$SSL_CFLAGS $CRYPTO_CFLAGS" + SECP_TEST_INCLUDES="$SSL_CFLAGS $CRYPTO_CFLAGS $CRYPTO_CPPFLAGS" SECP_TEST_LIBS="$CRYPTO_LIBS" case $host in @@ -412,29 +472,6 @@ else fi fi -if test x"$use_jni" != x"no"; then - AX_JNI_INCLUDE_DIR - have_jni_dependencies=yes - if test x"$enable_module_ecdh" = x"no"; then - have_jni_dependencies=no - fi - if test "x$JNI_INCLUDE_DIRS" = "x"; then - have_jni_dependencies=no - fi - if test "x$have_jni_dependencies" = "xno"; then - if test x"$use_jni" = x"yes"; then - AC_MSG_ERROR([jni support explicitly requested but headers/dependencies were not found. Enable ECDH and try again.]) - fi - AC_MSG_WARN([jni headers/dependencies not found. jni support disabled]) - use_jni=no - else - use_jni=yes - for JNI_INCLUDE_DIR in $JNI_INCLUDE_DIRS; do - JNI_INCLUDES="$JNI_INCLUDES -I$JNI_INCLUDE_DIR" - done - fi -fi - if test x"$set_bignum" = x"gmp"; then SECP_LIBS="$SECP_LIBS $GMP_LIBS" SECP_INCLUDES="$SECP_INCLUDES $GMP_CPPFLAGS" @@ -462,6 +499,10 @@ if test x"$use_external_asm" = x"yes"; then AC_DEFINE(USE_EXTERNAL_ASM, 1, [Define this symbol if an external (non-inline) assembly implementation is used]) fi +if test x"$use_external_default_callbacks" = x"yes"; then + AC_DEFINE(USE_EXTERNAL_DEFAULT_CALLBACKS, 1, [Define this symbol if an external implementation of the default callbacks is used]) +fi + if test x"$enable_experimental" = x"yes"; then AC_MSG_NOTICE([******]) AC_MSG_NOTICE([WARNING: experimental build]) @@ -479,7 +520,6 @@ fi AC_CONFIG_HEADERS([src/libsecp256k1-config.h]) AC_CONFIG_FILES([Makefile libsecp256k1.pc]) -AC_SUBST(JNI_INCLUDES) AC_SUBST(SECP_INCLUDES) AC_SUBST(SECP_LIBS) AC_SUBST(SECP_TEST_LIBS) @@ -491,7 +531,6 @@ AM_CONDITIONAL([USE_BENCHMARK], [test x"$use_benchmark" = x"yes"]) AM_CONDITIONAL([USE_ECMULT_STATIC_PRECOMPUTATION], [test x"$set_precomp" = x"yes"]) AM_CONDITIONAL([ENABLE_MODULE_ECDH], [test x"$enable_module_ecdh" = x"yes"]) AM_CONDITIONAL([ENABLE_MODULE_RECOVERY], [test x"$enable_module_recovery" = x"yes"]) -AM_CONDITIONAL([USE_JNI], [test x"$use_jni" = x"yes"]) AM_CONDITIONAL([USE_EXTERNAL_ASM], [test x"$use_external_asm" = x"yes"]) AM_CONDITIONAL([USE_ASM_ARM], [test x"$set_asm" = x"arm"]) @@ -504,21 +543,24 @@ AC_OUTPUT echo echo "Build Options:" -echo " with endomorphism = $use_endomorphism" -echo " with ecmult precomp = $set_precomp" -echo " with jni = $use_jni" -echo " with benchmarks = $use_benchmark" -echo " with coverage = $enable_coverage" -echo " module ecdh = $enable_module_ecdh" -echo " module recovery = $enable_module_recovery" +echo " with endomorphism = $use_endomorphism" +echo " with ecmult precomp = $set_precomp" +echo " with external callbacks = $use_external_default_callbacks" +echo " with benchmarks = $use_benchmark" +echo " with coverage = $enable_coverage" +echo " module ecdh = $enable_module_ecdh" +echo " module recovery = $enable_module_recovery" echo -echo " asm = $set_asm" -echo " bignum = $set_bignum" -echo " field = $set_field" -echo " scalar = $set_scalar" +echo " asm = $set_asm" +echo " bignum = $set_bignum" +echo " field = $set_field" +echo " scalar = $set_scalar" +echo " ecmult window size = $set_ecmult_window" +echo " ecmult gen prec. bits = $set_ecmult_gen_precision" echo -echo " CC = $CC" -echo " CFLAGS = $CFLAGS" -echo " CPPFLAGS = $CPPFLAGS" -echo " LDFLAGS = $LDFLAGS" +echo " valgrind = $enable_valgrind" +echo " CC = $CC" +echo " CFLAGS = $CFLAGS" +echo " CPPFLAGS = $CPPFLAGS" +echo " LDFLAGS = $LDFLAGS" echo diff --git a/src/secp256k1/contrib/lax_der_parsing.c b/src/secp256k1/contrib/lax_der_parsing.c index 5b141a9948..e177a0562d 100644 --- a/src/secp256k1/contrib/lax_der_parsing.c +++ b/src/secp256k1/contrib/lax_der_parsing.c @@ -32,7 +32,7 @@ int ecdsa_signature_parse_der_lax(const secp256k1_context* ctx, secp256k1_ecdsa_ lenbyte = input[pos++]; if (lenbyte & 0x80) { lenbyte -= 0x80; - if (pos + lenbyte > inputlen) { + if (lenbyte > inputlen - pos) { return 0; } pos += lenbyte; @@ -51,7 +51,7 @@ int ecdsa_signature_parse_der_lax(const secp256k1_context* ctx, secp256k1_ecdsa_ lenbyte = input[pos++]; if (lenbyte & 0x80) { lenbyte -= 0x80; - if (pos + lenbyte > inputlen) { + if (lenbyte > inputlen - pos) { return 0; } while (lenbyte > 0 && input[pos] == 0) { @@ -89,7 +89,7 @@ int ecdsa_signature_parse_der_lax(const secp256k1_context* ctx, secp256k1_ecdsa_ lenbyte = input[pos++]; if (lenbyte & 0x80) { lenbyte -= 0x80; - if (pos + lenbyte > inputlen) { + if (lenbyte > inputlen - pos) { return 0; } while (lenbyte > 0 && input[pos] == 0) { diff --git a/src/secp256k1/contrib/travis.sh b/src/secp256k1/contrib/travis.sh new file mode 100755 index 0000000000..3909d16a27 --- /dev/null +++ b/src/secp256k1/contrib/travis.sh @@ -0,0 +1,65 @@ +#!/bin/sh + +set -e +set -x + +if [ -n "$HOST" ] +then + export USE_HOST="--host=$HOST" +fi +if [ "$HOST" = "i686-linux-gnu" ] +then + export CC="$CC -m32" +fi +if [ "$TRAVIS_OS_NAME" = "osx" ] && [ "$TRAVIS_COMPILER" = "gcc" ] +then + export CC="gcc-9" +fi + +./configure \ + --enable-experimental="$EXPERIMENTAL" --enable-endomorphism="$ENDOMORPHISM" \ + --with-field="$FIELD" --with-bignum="$BIGNUM" --with-asm="$ASM" --with-scalar="$SCALAR" \ + --enable-ecmult-static-precomputation="$STATICPRECOMPUTATION" --with-ecmult-gen-precision="$ECMULTGENPRECISION" \ + --enable-module-ecdh="$ECDH" --enable-module-recovery="$RECOVERY" "$EXTRAFLAGS" "$USE_HOST" + +if [ -n "$BUILD" ] +then + make -j2 "$BUILD" +fi +if [ -n "$VALGRIND" ] +then + make -j2 + # the `--error-exitcode` is required to make the test fail if valgrind found errors, otherwise it'll return 0 (http://valgrind.org/docs/manual/manual-core.html) + valgrind --error-exitcode=42 ./tests 16 + valgrind --error-exitcode=42 ./exhaustive_tests +fi +if [ -n "$BENCH" ] +then + if [ -n "$VALGRIND" ] + then + # Using the local `libtool` because on macOS the system's libtool has nothing to do with GNU libtool + EXEC='./libtool --mode=execute valgrind --error-exitcode=42' + else + EXEC= + fi + # This limits the iterations in the benchmarks below to ITER(set in .travis.yml) iterations. + export SECP256K1_BENCH_ITERS="$ITERS" + { + $EXEC ./bench_ecmult + $EXEC ./bench_internal + $EXEC ./bench_sign + $EXEC ./bench_verify + } >> bench.log 2>&1 + if [ "$RECOVERY" = "yes" ] + then + $EXEC ./bench_recover >> bench.log 2>&1 + fi + if [ "$ECDH" = "yes" ] + then + $EXEC ./bench_ecdh >> bench.log 2>&1 + fi +fi +if [ -n "$CTIMETEST" ] +then + ./libtool --mode=execute valgrind --error-exitcode=42 ./valgrind_ctime_test > valgrind_ctime_test.log 2>&1 +fi diff --git a/src/secp256k1/include/secp256k1.h b/src/secp256k1/include/secp256k1.h index 43af09c330..2ba2dca388 100644 --- a/src/secp256k1/include/secp256k1.h +++ b/src/secp256k1/include/secp256k1.h @@ -14,7 +14,7 @@ extern "C" { * 2. Array lengths always immediately the follow the argument whose length * they describe, even if this violates rule 1. * 3. Within the OUT/OUTIN/IN groups, pointers to data that is typically generated - * later go first. This means: signatures, public nonces, private nonces, + * later go first. This means: signatures, public nonces, secret nonces, * messages, public keys, secret keys, tweaks. * 4. Arguments that are not data pointers go last, from more complex to less * complex: function pointers, algorithm names, messages, void pointers, @@ -33,9 +33,10 @@ extern "C" { * verification). * * A constructed context can safely be used from multiple threads - * simultaneously, but API call that take a non-const pointer to a context + * simultaneously, but API calls that take a non-const pointer to a context * need exclusive access to it. In particular this is the case for - * secp256k1_context_destroy and secp256k1_context_randomize. + * secp256k1_context_destroy, secp256k1_context_preallocated_destroy, + * and secp256k1_context_randomize. * * Regarding randomization, either do it once at creation time (in which case * you do not need any locking for the other calls), or use a read-write lock. @@ -161,14 +162,17 @@ typedef int (*secp256k1_nonce_function)( /** The higher bits contain the actual data. Do not use directly. */ #define SECP256K1_FLAGS_BIT_CONTEXT_VERIFY (1 << 8) #define SECP256K1_FLAGS_BIT_CONTEXT_SIGN (1 << 9) +#define SECP256K1_FLAGS_BIT_CONTEXT_DECLASSIFY (1 << 10) #define SECP256K1_FLAGS_BIT_COMPRESSION (1 << 8) -/** Flags to pass to secp256k1_context_create. */ +/** Flags to pass to secp256k1_context_create, secp256k1_context_preallocated_size, and + * secp256k1_context_preallocated_create. */ #define SECP256K1_CONTEXT_VERIFY (SECP256K1_FLAGS_TYPE_CONTEXT | SECP256K1_FLAGS_BIT_CONTEXT_VERIFY) #define SECP256K1_CONTEXT_SIGN (SECP256K1_FLAGS_TYPE_CONTEXT | SECP256K1_FLAGS_BIT_CONTEXT_SIGN) +#define SECP256K1_CONTEXT_DECLASSIFY (SECP256K1_FLAGS_TYPE_CONTEXT | SECP256K1_FLAGS_BIT_CONTEXT_DECLASSIFY) #define SECP256K1_CONTEXT_NONE (SECP256K1_FLAGS_TYPE_CONTEXT) -/** Flag to pass to secp256k1_ec_pubkey_serialize and secp256k1_ec_privkey_export. */ +/** Flag to pass to secp256k1_ec_pubkey_serialize. */ #define SECP256K1_EC_COMPRESSED (SECP256K1_FLAGS_TYPE_COMPRESSION | SECP256K1_FLAGS_BIT_COMPRESSION) #define SECP256K1_EC_UNCOMPRESSED (SECP256K1_FLAGS_TYPE_COMPRESSION) @@ -186,7 +190,11 @@ typedef int (*secp256k1_nonce_function)( */ SECP256K1_API extern const secp256k1_context *secp256k1_context_no_precomp; -/** Create a secp256k1 context object. +/** Create a secp256k1 context object (in dynamically allocated memory). + * + * This function uses malloc to allocate memory. It is guaranteed that malloc is + * called at most once for every call of this function. If you need to avoid dynamic + * memory allocation entirely, see the functions in secp256k1_preallocated.h. * * Returns: a newly created context object. * In: flags: which parts of the context to initialize. @@ -197,7 +205,11 @@ SECP256K1_API secp256k1_context* secp256k1_context_create( unsigned int flags ) SECP256K1_WARN_UNUSED_RESULT; -/** Copies a secp256k1 context object. +/** Copy a secp256k1 context object (into dynamically allocated memory). + * + * This function uses malloc to allocate memory. It is guaranteed that malloc is + * called at most once for every call of this function. If you need to avoid dynamic + * memory allocation entirely, see the functions in secp256k1_preallocated.h. * * Returns: a newly created context object. * Args: ctx: an existing context to copy (cannot be NULL) @@ -206,10 +218,18 @@ SECP256K1_API secp256k1_context* secp256k1_context_clone( const secp256k1_context* ctx ) SECP256K1_ARG_NONNULL(1) SECP256K1_WARN_UNUSED_RESULT; -/** Destroy a secp256k1 context object. +/** Destroy a secp256k1 context object (created in dynamically allocated memory). * * The context pointer may not be used afterwards. - * Args: ctx: an existing context to destroy (cannot be NULL) + * + * The context to destroy must have been created using secp256k1_context_create + * or secp256k1_context_clone. If the context has instead been created using + * secp256k1_context_preallocated_create or secp256k1_context_preallocated_clone, the + * behaviour is undefined. In that case, secp256k1_context_preallocated_destroy must + * be used instead. + * + * Args: ctx: an existing context to destroy, constructed using + * secp256k1_context_create or secp256k1_context_clone */ SECP256K1_API void secp256k1_context_destroy( secp256k1_context* ctx @@ -229,11 +249,28 @@ SECP256K1_API void secp256k1_context_destroy( * to cause a crash, though its return value and output arguments are * undefined. * + * When this function has not been called (or called with fn==NULL), then the + * default handler will be used. The library provides a default handler which + * writes the message to stderr and calls abort. This default handler can be + * replaced at link time if the preprocessor macro + * USE_EXTERNAL_DEFAULT_CALLBACKS is defined, which is the case if the build + * has been configured with --enable-external-default-callbacks. Then the + * following two symbols must be provided to link against: + * - void secp256k1_default_illegal_callback_fn(const char* message, void* data); + * - void secp256k1_default_error_callback_fn(const char* message, void* data); + * The library can call these default handlers even before a proper callback data + * pointer could have been set using secp256k1_context_set_illegal_callback or + * secp256k1_context_set_error_callback, e.g., when the creation of a context + * fails. In this case, the corresponding default handler will be called with + * the data pointer argument set to NULL. + * * Args: ctx: an existing context object (cannot be NULL) * In: fun: a pointer to a function to call when an illegal argument is - * passed to the API, taking a message and an opaque pointer - * (NULL restores a default handler that calls abort). + * passed to the API, taking a message and an opaque pointer. + * (NULL restores the default handler.) * data: the opaque pointer to pass to fun above. + * + * See also secp256k1_context_set_error_callback. */ SECP256K1_API void secp256k1_context_set_illegal_callback( secp256k1_context* ctx, @@ -253,9 +290,12 @@ SECP256K1_API void secp256k1_context_set_illegal_callback( * * Args: ctx: an existing context object (cannot be NULL) * In: fun: a pointer to a function to call when an internal error occurs, - * taking a message and an opaque pointer (NULL restores a default - * handler that calls abort). + * taking a message and an opaque pointer (NULL restores the + * default handler, see secp256k1_context_set_illegal_callback + * for details). * data: the opaque pointer to pass to fun above. + * + * See also secp256k1_context_set_illegal_callback. */ SECP256K1_API void secp256k1_context_set_error_callback( secp256k1_context* ctx, @@ -267,21 +307,24 @@ SECP256K1_API void secp256k1_context_set_error_callback( * * Returns: a newly created scratch space. * Args: ctx: an existing context object (cannot be NULL) - * In: max_size: maximum amount of memory to allocate + * In: size: amount of memory to be available as scratch space. Some extra + * (<100 bytes) will be allocated for extra accounting. */ SECP256K1_API SECP256K1_WARN_UNUSED_RESULT secp256k1_scratch_space* secp256k1_scratch_space_create( const secp256k1_context* ctx, - size_t max_size + size_t size ) SECP256K1_ARG_NONNULL(1); /** Destroy a secp256k1 scratch space. * * The pointer may not be used afterwards. - * Args: scratch: space to destroy + * Args: ctx: a secp256k1 context object. + * scratch: space to destroy */ SECP256K1_API void secp256k1_scratch_space_destroy( + const secp256k1_context* ctx, secp256k1_scratch_space* scratch -); +) SECP256K1_ARG_NONNULL(1); /** Parse a variable-length public key into the pubkey object. * @@ -488,7 +531,7 @@ SECP256K1_API extern const secp256k1_nonce_function secp256k1_nonce_function_def /** Create an ECDSA signature. * * Returns: 1: signature created - * 0: the nonce generation function failed, or the private key was invalid. + * 0: the nonce generation function failed, or the secret key was invalid. * Args: ctx: pointer to a context object, initialized for signing (cannot be NULL) * Out: sig: pointer to an array where the signature will be placed (cannot be NULL) * In: msg32: the 32-byte message hash being signed (cannot be NULL) @@ -510,6 +553,11 @@ SECP256K1_API int secp256k1_ecdsa_sign( /** Verify an ECDSA secret key. * + * A secret key is valid if it is not 0 and less than the secp256k1 curve order + * when interpreted as an integer (most significant byte first). The + * probability of choosing a 32-byte string uniformly at random which is an + * invalid secret key is negligible. + * * Returns: 1: secret key is valid * 0: secret key is invalid * Args: ctx: pointer to a context object (cannot be NULL) @@ -526,7 +574,7 @@ SECP256K1_API SECP256K1_WARN_UNUSED_RESULT int secp256k1_ec_seckey_verify( * 0: secret was invalid, try again * Args: ctx: pointer to a context object, initialized for signing (cannot be NULL) * Out: pubkey: pointer to the created public key (cannot be NULL) - * In: seckey: pointer to a 32-byte private key (cannot be NULL) + * In: seckey: pointer to a 32-byte secret key (cannot be NULL) */ SECP256K1_API SECP256K1_WARN_UNUSED_RESULT int secp256k1_ec_pubkey_create( const secp256k1_context* ctx, @@ -534,12 +582,24 @@ SECP256K1_API SECP256K1_WARN_UNUSED_RESULT int secp256k1_ec_pubkey_create( const unsigned char *seckey ) SECP256K1_ARG_NONNULL(1) SECP256K1_ARG_NONNULL(2) SECP256K1_ARG_NONNULL(3); -/** Negates a private key in place. +/** Negates a secret key in place. * - * Returns: 1 always - * Args: ctx: pointer to a context object - * In/Out: seckey: pointer to the 32-byte private key to be negated (cannot be NULL) + * Returns: 0 if the given secret key is invalid according to + * secp256k1_ec_seckey_verify. 1 otherwise + * Args: ctx: pointer to a context object + * In/Out: seckey: pointer to the 32-byte secret key to be negated. If the + * secret key is invalid according to + * secp256k1_ec_seckey_verify, this function returns 0 and + * seckey will be set to some unspecified value. (cannot be + * NULL) */ +SECP256K1_API SECP256K1_WARN_UNUSED_RESULT int secp256k1_ec_seckey_negate( + const secp256k1_context* ctx, + unsigned char *seckey +) SECP256K1_ARG_NONNULL(1) SECP256K1_ARG_NONNULL(2); + +/** Same as secp256k1_ec_seckey_negate, but DEPRECATED. Will be removed in + * future versions. */ SECP256K1_API SECP256K1_WARN_UNUSED_RESULT int secp256k1_ec_privkey_negate( const secp256k1_context* ctx, unsigned char *seckey @@ -556,15 +616,29 @@ SECP256K1_API SECP256K1_WARN_UNUSED_RESULT int secp256k1_ec_pubkey_negate( secp256k1_pubkey *pubkey ) SECP256K1_ARG_NONNULL(1) SECP256K1_ARG_NONNULL(2); -/** Tweak a private key by adding tweak to it. - * Returns: 0 if the tweak was out of range (chance of around 1 in 2^128 for - * uniformly random 32-byte arrays, or if the resulting private key - * would be invalid (only when the tweak is the complement of the - * private key). 1 otherwise. - * Args: ctx: pointer to a context object (cannot be NULL). - * In/Out: seckey: pointer to a 32-byte private key. - * In: tweak: pointer to a 32-byte tweak. - */ +/** Tweak a secret key by adding tweak to it. + * + * Returns: 0 if the arguments are invalid or the resulting secret key would be + * invalid (only when the tweak is the negation of the secret key). 1 + * otherwise. + * Args: ctx: pointer to a context object (cannot be NULL). + * In/Out: seckey: pointer to a 32-byte secret key. If the secret key is + * invalid according to secp256k1_ec_seckey_verify, this + * function returns 0. seckey will be set to some unspecified + * value if this function returns 0. (cannot be NULL) + * In: tweak: pointer to a 32-byte tweak. If the tweak is invalid according to + * secp256k1_ec_seckey_verify, this function returns 0. For + * uniformly random 32-byte arrays the chance of being invalid + * is negligible (around 1 in 2^128) (cannot be NULL). + */ +SECP256K1_API SECP256K1_WARN_UNUSED_RESULT int secp256k1_ec_seckey_tweak_add( + const secp256k1_context* ctx, + unsigned char *seckey, + const unsigned char *tweak +) SECP256K1_ARG_NONNULL(1) SECP256K1_ARG_NONNULL(2) SECP256K1_ARG_NONNULL(3); + +/** Same as secp256k1_ec_seckey_tweak_add, but DEPRECATED. Will be removed in + * future versions. */ SECP256K1_API SECP256K1_WARN_UNUSED_RESULT int secp256k1_ec_privkey_tweak_add( const secp256k1_context* ctx, unsigned char *seckey, @@ -572,14 +646,18 @@ SECP256K1_API SECP256K1_WARN_UNUSED_RESULT int secp256k1_ec_privkey_tweak_add( ) SECP256K1_ARG_NONNULL(1) SECP256K1_ARG_NONNULL(2) SECP256K1_ARG_NONNULL(3); /** Tweak a public key by adding tweak times the generator to it. - * Returns: 0 if the tweak was out of range (chance of around 1 in 2^128 for - * uniformly random 32-byte arrays, or if the resulting public key - * would be invalid (only when the tweak is the complement of the - * corresponding private key). 1 otherwise. - * Args: ctx: pointer to a context object initialized for validation + * + * Returns: 0 if the arguments are invalid or the resulting public key would be + * invalid (only when the tweak is the negation of the corresponding + * secret key). 1 otherwise. + * Args: ctx: pointer to a context object initialized for validation * (cannot be NULL). - * In/Out: pubkey: pointer to a public key object. - * In: tweak: pointer to a 32-byte tweak. + * In/Out: pubkey: pointer to a public key object. pubkey will be set to an + * invalid value if this function returns 0 (cannot be NULL). + * In: tweak: pointer to a 32-byte tweak. If the tweak is invalid according to + * secp256k1_ec_seckey_verify, this function returns 0. For + * uniformly random 32-byte arrays the chance of being invalid + * is negligible (around 1 in 2^128) (cannot be NULL). */ SECP256K1_API SECP256K1_WARN_UNUSED_RESULT int secp256k1_ec_pubkey_tweak_add( const secp256k1_context* ctx, @@ -587,13 +665,27 @@ SECP256K1_API SECP256K1_WARN_UNUSED_RESULT int secp256k1_ec_pubkey_tweak_add( const unsigned char *tweak ) SECP256K1_ARG_NONNULL(1) SECP256K1_ARG_NONNULL(2) SECP256K1_ARG_NONNULL(3); -/** Tweak a private key by multiplying it by a tweak. - * Returns: 0 if the tweak was out of range (chance of around 1 in 2^128 for - * uniformly random 32-byte arrays, or equal to zero. 1 otherwise. - * Args: ctx: pointer to a context object (cannot be NULL). - * In/Out: seckey: pointer to a 32-byte private key. - * In: tweak: pointer to a 32-byte tweak. +/** Tweak a secret key by multiplying it by a tweak. + * + * Returns: 0 if the arguments are invalid. 1 otherwise. + * Args: ctx: pointer to a context object (cannot be NULL). + * In/Out: seckey: pointer to a 32-byte secret key. If the secret key is + * invalid according to secp256k1_ec_seckey_verify, this + * function returns 0. seckey will be set to some unspecified + * value if this function returns 0. (cannot be NULL) + * In: tweak: pointer to a 32-byte tweak. If the tweak is invalid according to + * secp256k1_ec_seckey_verify, this function returns 0. For + * uniformly random 32-byte arrays the chance of being invalid + * is negligible (around 1 in 2^128) (cannot be NULL). */ +SECP256K1_API SECP256K1_WARN_UNUSED_RESULT int secp256k1_ec_seckey_tweak_mul( + const secp256k1_context* ctx, + unsigned char *seckey, + const unsigned char *tweak +) SECP256K1_ARG_NONNULL(1) SECP256K1_ARG_NONNULL(2) SECP256K1_ARG_NONNULL(3); + +/** Same as secp256k1_ec_seckey_tweak_mul, but DEPRECATED. Will be removed in + * future versions. */ SECP256K1_API SECP256K1_WARN_UNUSED_RESULT int secp256k1_ec_privkey_tweak_mul( const secp256k1_context* ctx, unsigned char *seckey, @@ -601,12 +693,16 @@ SECP256K1_API SECP256K1_WARN_UNUSED_RESULT int secp256k1_ec_privkey_tweak_mul( ) SECP256K1_ARG_NONNULL(1) SECP256K1_ARG_NONNULL(2) SECP256K1_ARG_NONNULL(3); /** Tweak a public key by multiplying it by a tweak value. - * Returns: 0 if the tweak was out of range (chance of around 1 in 2^128 for - * uniformly random 32-byte arrays, or equal to zero. 1 otherwise. - * Args: ctx: pointer to a context object initialized for validation - * (cannot be NULL). - * In/Out: pubkey: pointer to a public key obkect. - * In: tweak: pointer to a 32-byte tweak. + * + * Returns: 0 if the arguments are invalid. 1 otherwise. + * Args: ctx: pointer to a context object initialized for validation + * (cannot be NULL). + * In/Out: pubkey: pointer to a public key object. pubkey will be set to an + * invalid value if this function returns 0 (cannot be NULL). + * In: tweak: pointer to a 32-byte tweak. If the tweak is invalid according to + * secp256k1_ec_seckey_verify, this function returns 0. For + * uniformly random 32-byte arrays the chance of being invalid + * is negligible (around 1 in 2^128) (cannot be NULL). */ SECP256K1_API SECP256K1_WARN_UNUSED_RESULT int secp256k1_ec_pubkey_tweak_mul( const secp256k1_context* ctx, @@ -636,7 +732,8 @@ SECP256K1_API SECP256K1_WARN_UNUSED_RESULT int secp256k1_ec_pubkey_tweak_mul( * contexts not initialized for signing; then it will have no effect and return 1. * * You should call this after secp256k1_context_create or - * secp256k1_context_clone, and may call this repeatedly afterwards. + * secp256k1_context_clone (and secp256k1_context_preallocated_create or + * secp256k1_context_clone, resp.), and you may call this repeatedly afterwards. */ SECP256K1_API SECP256K1_WARN_UNUSED_RESULT int secp256k1_context_randomize( secp256k1_context* ctx, @@ -644,6 +741,7 @@ SECP256K1_API SECP256K1_WARN_UNUSED_RESULT int secp256k1_context_randomize( ) SECP256K1_ARG_NONNULL(1); /** Add a number of public keys together. + * * Returns: 1: the sum of the public keys is valid. * 0: the sum of the public keys is not valid. * Args: ctx: pointer to a context object diff --git a/src/secp256k1/include/secp256k1_ecdh.h b/src/secp256k1/include/secp256k1_ecdh.h index df5fde235c..4058e9c043 100644 --- a/src/secp256k1/include/secp256k1_ecdh.h +++ b/src/secp256k1/include/secp256k1_ecdh.h @@ -7,43 +7,50 @@ extern "C" { #endif -/** A pointer to a function that applies hash function to a point +/** A pointer to a function that hashes an EC point to obtain an ECDH secret * - * Returns: 1 if a point was successfully hashed. 0 will cause ecdh to fail - * Out: output: pointer to an array to be filled by the function - * In: x: pointer to a 32-byte x coordinate - * y: pointer to a 32-byte y coordinate - * data: Arbitrary data pointer that is passed through + * Returns: 1 if the point was successfully hashed. + * 0 will cause secp256k1_ecdh to fail and return 0. + * Other return values are not allowed, and the behaviour of + * secp256k1_ecdh is undefined for other return values. + * Out: output: pointer to an array to be filled by the function + * In: x32: pointer to a 32-byte x coordinate + * y32: pointer to a 32-byte y coordinate + * data: arbitrary data pointer that is passed through */ typedef int (*secp256k1_ecdh_hash_function)( unsigned char *output, - const unsigned char *x, - const unsigned char *y, + const unsigned char *x32, + const unsigned char *y32, void *data ); -/** An implementation of SHA256 hash function that applies to compressed public key. */ +/** An implementation of SHA256 hash function that applies to compressed public key. + * Populates the output parameter with 32 bytes. */ SECP256K1_API extern const secp256k1_ecdh_hash_function secp256k1_ecdh_hash_function_sha256; -/** A default ecdh hash function (currently equal to secp256k1_ecdh_hash_function_sha256). */ +/** A default ECDH hash function (currently equal to secp256k1_ecdh_hash_function_sha256). + * Populates the output parameter with 32 bytes. */ SECP256K1_API extern const secp256k1_ecdh_hash_function secp256k1_ecdh_hash_function_default; /** Compute an EC Diffie-Hellman secret in constant time + * * Returns: 1: exponentiation was successful - * 0: scalar was invalid (zero or overflow) + * 0: scalar was invalid (zero or overflow) or hashfp returned 0 * Args: ctx: pointer to a context object (cannot be NULL) - * Out: output: pointer to an array to be filled by the function + * Out: output: pointer to an array to be filled by hashfp * In: pubkey: a pointer to a secp256k1_pubkey containing an * initialized public key - * privkey: a 32-byte scalar with which to multiply the point + * seckey: a 32-byte scalar with which to multiply the point * hashfp: pointer to a hash function. If NULL, secp256k1_ecdh_hash_function_sha256 is used - * data: Arbitrary data pointer that is passed through + * (in which case, 32 bytes will be written to output) + * data: arbitrary data pointer that is passed through to hashfp */ SECP256K1_API SECP256K1_WARN_UNUSED_RESULT int secp256k1_ecdh( const secp256k1_context* ctx, unsigned char *output, const secp256k1_pubkey *pubkey, - const unsigned char *privkey, + const unsigned char *seckey, secp256k1_ecdh_hash_function hashfp, void *data ) SECP256K1_ARG_NONNULL(1) SECP256K1_ARG_NONNULL(2) SECP256K1_ARG_NONNULL(3) SECP256K1_ARG_NONNULL(4); diff --git a/src/secp256k1/include/secp256k1_preallocated.h b/src/secp256k1/include/secp256k1_preallocated.h new file mode 100644 index 0000000000..a9ae15d5ae --- /dev/null +++ b/src/secp256k1/include/secp256k1_preallocated.h @@ -0,0 +1,128 @@ +#ifndef SECP256K1_PREALLOCATED_H +#define SECP256K1_PREALLOCATED_H + +#include "secp256k1.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/* The module provided by this header file is intended for settings in which it + * is not possible or desirable to rely on dynamic memory allocation. It provides + * functions for creating, cloning, and destroying secp256k1 context objects in a + * contiguous fixed-size block of memory provided by the caller. + * + * Context objects created by functions in this module can be used like contexts + * objects created by functions in secp256k1.h, i.e., they can be passed to any + * API function that expects a context object (see secp256k1.h for details). The + * only exception is that context objects created by functions in this module + * must be destroyed using secp256k1_context_preallocated_destroy (in this + * module) instead of secp256k1_context_destroy (in secp256k1.h). + * + * It is guaranteed that functions in this module will not call malloc or its + * friends realloc, calloc, and free. + */ + +/** Determine the memory size of a secp256k1 context object to be created in + * caller-provided memory. + * + * The purpose of this function is to determine how much memory must be provided + * to secp256k1_context_preallocated_create. + * + * Returns: the required size of the caller-provided memory block + * In: flags: which parts of the context to initialize. + */ +SECP256K1_API size_t secp256k1_context_preallocated_size( + unsigned int flags +) SECP256K1_WARN_UNUSED_RESULT; + +/** Create a secp256k1 context object in caller-provided memory. + * + * The caller must provide a pointer to a rewritable contiguous block of memory + * of size at least secp256k1_context_preallocated_size(flags) bytes, suitably + * aligned to hold an object of any type. + * + * The block of memory is exclusively owned by the created context object during + * the lifetime of this context object, which begins with the call to this + * function and ends when a call to secp256k1_context_preallocated_destroy + * (which destroys the context object again) returns. During the lifetime of the + * context object, the caller is obligated not to access this block of memory, + * i.e., the caller may not read or write the memory, e.g., by copying the memory + * contents to a different location or trying to create a second context object + * in the memory. In simpler words, the prealloc pointer (or any pointer derived + * from it) should not be used during the lifetime of the context object. + * + * Returns: a newly created context object. + * In: prealloc: a pointer to a rewritable contiguous block of memory of + * size at least secp256k1_context_preallocated_size(flags) + * bytes, as detailed above (cannot be NULL) + * flags: which parts of the context to initialize. + * + * See also secp256k1_context_randomize (in secp256k1.h) + * and secp256k1_context_preallocated_destroy. + */ +SECP256K1_API secp256k1_context* secp256k1_context_preallocated_create( + void* prealloc, + unsigned int flags +) SECP256K1_ARG_NONNULL(1) SECP256K1_WARN_UNUSED_RESULT; + +/** Determine the memory size of a secp256k1 context object to be copied into + * caller-provided memory. + * + * Returns: the required size of the caller-provided memory block. + * In: ctx: an existing context to copy (cannot be NULL) + */ +SECP256K1_API size_t secp256k1_context_preallocated_clone_size( + const secp256k1_context* ctx +) SECP256K1_ARG_NONNULL(1) SECP256K1_WARN_UNUSED_RESULT; + +/** Copy a secp256k1 context object into caller-provided memory. + * + * The caller must provide a pointer to a rewritable contiguous block of memory + * of size at least secp256k1_context_preallocated_size(flags) bytes, suitably + * aligned to hold an object of any type. + * + * The block of memory is exclusively owned by the created context object during + * the lifetime of this context object, see the description of + * secp256k1_context_preallocated_create for details. + * + * Returns: a newly created context object. + * Args: ctx: an existing context to copy (cannot be NULL) + * In: prealloc: a pointer to a rewritable contiguous block of memory of + * size at least secp256k1_context_preallocated_size(flags) + * bytes, as detailed above (cannot be NULL) + */ +SECP256K1_API secp256k1_context* secp256k1_context_preallocated_clone( + const secp256k1_context* ctx, + void* prealloc +) SECP256K1_ARG_NONNULL(1) SECP256K1_ARG_NONNULL(2) SECP256K1_WARN_UNUSED_RESULT; + +/** Destroy a secp256k1 context object that has been created in + * caller-provided memory. + * + * The context pointer may not be used afterwards. + * + * The context to destroy must have been created using + * secp256k1_context_preallocated_create or secp256k1_context_preallocated_clone. + * If the context has instead been created using secp256k1_context_create or + * secp256k1_context_clone, the behaviour is undefined. In that case, + * secp256k1_context_destroy must be used instead. + * + * If required, it is the responsibility of the caller to deallocate the block + * of memory properly after this function returns, e.g., by calling free on the + * preallocated pointer given to secp256k1_context_preallocated_create or + * secp256k1_context_preallocated_clone. + * + * Args: ctx: an existing context to destroy, constructed using + * secp256k1_context_preallocated_create or + * secp256k1_context_preallocated_clone (cannot be NULL) + */ +SECP256K1_API void secp256k1_context_preallocated_destroy( + secp256k1_context* ctx +); + +#ifdef __cplusplus +} +#endif + +#endif /* SECP256K1_PREALLOCATED_H */ diff --git a/src/secp256k1/include/secp256k1_recovery.h b/src/secp256k1/include/secp256k1_recovery.h index cf6c5ed7f5..f8ccaecd3d 100644 --- a/src/secp256k1/include/secp256k1_recovery.h +++ b/src/secp256k1/include/secp256k1_recovery.h @@ -70,7 +70,7 @@ SECP256K1_API int secp256k1_ecdsa_recoverable_signature_serialize_compact( /** Create a recoverable ECDSA signature. * * Returns: 1: signature created - * 0: the nonce generation function failed, or the private key was invalid. + * 0: the nonce generation function failed, or the secret key was invalid. * Args: ctx: pointer to a context object, initialized for signing (cannot be NULL) * Out: sig: pointer to an array where the signature will be placed (cannot be NULL) * In: msg32: the 32-byte message hash being signed (cannot be NULL) diff --git a/src/secp256k1/src/asm/field_10x26_arm.s b/src/secp256k1/src/asm/field_10x26_arm.s index 5a9cc3ffcf..9a5bd06721 100644 --- a/src/secp256k1/src/asm/field_10x26_arm.s +++ b/src/secp256k1/src/asm/field_10x26_arm.s @@ -16,15 +16,9 @@ Note: */ .syntax unified - .arch armv7-a @ eabi attributes - see readelf -A - .eabi_attribute 8, 1 @ Tag_ARM_ISA_use = yes - .eabi_attribute 9, 0 @ Tag_Thumb_ISA_use = no - .eabi_attribute 10, 0 @ Tag_FP_arch = none .eabi_attribute 24, 1 @ Tag_ABI_align_needed = 8-byte .eabi_attribute 25, 1 @ Tag_ABI_align_preserved = 8-byte, except leaf SP - .eabi_attribute 30, 2 @ Tag_ABI_optimization_goals = Aggressive Speed - .eabi_attribute 34, 1 @ Tag_CPU_unaligned_access = v6 .text @ Field constants diff --git a/src/secp256k1/src/basic-config.h b/src/secp256k1/src/basic-config.h index fc588061ca..e9be39d4ca 100644 --- a/src/secp256k1/src/basic-config.h +++ b/src/secp256k1/src/basic-config.h @@ -10,7 +10,10 @@ #ifdef USE_BASIC_CONFIG #undef USE_ASM_X86_64 +#undef USE_ECMULT_STATIC_PRECOMPUTATION #undef USE_ENDOMORPHISM +#undef USE_EXTERNAL_ASM +#undef USE_EXTERNAL_DEFAULT_CALLBACKS #undef USE_FIELD_10X26 #undef USE_FIELD_5X52 #undef USE_FIELD_INV_BUILTIN @@ -21,12 +24,14 @@ #undef USE_SCALAR_8X32 #undef USE_SCALAR_INV_BUILTIN #undef USE_SCALAR_INV_NUM +#undef ECMULT_WINDOW_SIZE #define USE_NUM_NONE 1 #define USE_FIELD_INV_BUILTIN 1 #define USE_SCALAR_INV_BUILTIN 1 #define USE_FIELD_10X26 1 #define USE_SCALAR_8X32 1 +#define ECMULT_WINDOW_SIZE 15 #endif /* USE_BASIC_CONFIG */ diff --git a/src/secp256k1/src/bench.h b/src/secp256k1/src/bench.h index 5b59783f68..9bfed903e0 100644 --- a/src/secp256k1/src/bench.h +++ b/src/secp256k1/src/bench.h @@ -7,45 +7,87 @@ #ifndef SECP256K1_BENCH_H #define SECP256K1_BENCH_H +#include <stdint.h> #include <stdio.h> #include <string.h> -#include <math.h> #include "sys/time.h" -static double gettimedouble(void) { +static int64_t gettime_i64(void) { struct timeval tv; gettimeofday(&tv, NULL); - return tv.tv_usec * 0.000001 + tv.tv_sec; + return (int64_t)tv.tv_usec + (int64_t)tv.tv_sec * 1000000LL; } -void print_number(double x) { - double y = x; - int c = 0; - if (y < 0.0) { - y = -y; +#define FP_EXP (6) +#define FP_MULT (1000000LL) + +/* Format fixed point number. */ +void print_number(const int64_t x) { + int64_t x_abs, y; + int c, i, rounding; + size_t ptr; + char buffer[30]; + + if (x == INT64_MIN) { + /* Prevent UB. */ + printf("ERR"); + return; } - while (y > 0 && y < 100.0) { - y *= 10.0; + x_abs = x < 0 ? -x : x; + + /* Determine how many decimals we want to show (more than FP_EXP makes no + * sense). */ + y = x_abs; + c = 0; + while (y > 0LL && y < 100LL * FP_MULT && c < FP_EXP) { + y *= 10LL; c++; } - printf("%.*f", c, x); + + /* Round to 'c' decimals. */ + y = x_abs; + rounding = 0; + for (i = c; i < FP_EXP; ++i) { + rounding = (y % 10) >= 5; + y /= 10; + } + y += rounding; + + /* Format and print the number. */ + ptr = sizeof(buffer) - 1; + buffer[ptr] = 0; + if (c != 0) { + for (i = 0; i < c; ++i) { + buffer[--ptr] = '0' + (y % 10); + y /= 10; + } + buffer[--ptr] = '.'; + } + do { + buffer[--ptr] = '0' + (y % 10); + y /= 10; + } while (y != 0); + if (x < 0) { + buffer[--ptr] = '-'; + } + printf("%s", &buffer[ptr]); } -void run_benchmark(char *name, void (*benchmark)(void*), void (*setup)(void*), void (*teardown)(void*), void* data, int count, int iter) { +void run_benchmark(char *name, void (*benchmark)(void*, int), void (*setup)(void*), void (*teardown)(void*, int), void* data, int count, int iter) { int i; - double min = HUGE_VAL; - double sum = 0.0; - double max = 0.0; + int64_t min = INT64_MAX; + int64_t sum = 0; + int64_t max = 0; for (i = 0; i < count; i++) { - double begin, total; + int64_t begin, total; if (setup != NULL) { setup(data); } - begin = gettimedouble(); - benchmark(data); - total = gettimedouble() - begin; + begin = gettime_i64(); + benchmark(data, iter); + total = gettime_i64() - begin; if (teardown != NULL) { - teardown(data); + teardown(data, iter); } if (total < min) { min = total; @@ -56,11 +98,11 @@ void run_benchmark(char *name, void (*benchmark)(void*), void (*setup)(void*), v sum += total; } printf("%s: min ", name); - print_number(min * 1000000.0 / iter); + print_number(min * FP_MULT / iter); printf("us / avg "); - print_number((sum / count) * 1000000.0 / iter); + print_number(((sum * FP_MULT) / count) / iter); printf("us / max "); - print_number(max * 1000000.0 / iter); + print_number(max * FP_MULT / iter); printf("us\n"); } @@ -79,4 +121,13 @@ int have_flag(int argc, char** argv, char *flag) { return 0; } +int get_iters(int default_iters) { + char* env = getenv("SECP256K1_BENCH_ITERS"); + if (env) { + return strtol(env, NULL, 0); + } else { + return default_iters; + } +} + #endif /* SECP256K1_BENCH_H */ diff --git a/src/secp256k1/src/bench_ecdh.c b/src/secp256k1/src/bench_ecdh.c index c1dd5a6ac9..f099d33884 100644 --- a/src/secp256k1/src/bench_ecdh.c +++ b/src/secp256k1/src/bench_ecdh.c @@ -28,20 +28,18 @@ static void bench_ecdh_setup(void* arg) { 0xa2, 0xba, 0xd1, 0x84, 0xf8, 0x83, 0xc6, 0x9f }; - /* create a context with no capabilities */ - data->ctx = secp256k1_context_create(SECP256K1_FLAGS_TYPE_CONTEXT); for (i = 0; i < 32; i++) { data->scalar[i] = i + 1; } CHECK(secp256k1_ec_pubkey_parse(data->ctx, &data->point, point, sizeof(point)) == 1); } -static void bench_ecdh(void* arg) { +static void bench_ecdh(void* arg, int iters) { int i; unsigned char res[32]; bench_ecdh_data *data = (bench_ecdh_data*)arg; - for (i = 0; i < 20000; i++) { + for (i = 0; i < iters; i++) { CHECK(secp256k1_ecdh(data->ctx, res, &data->point, data->scalar, NULL, NULL) == 1); } } @@ -49,6 +47,13 @@ static void bench_ecdh(void* arg) { int main(void) { bench_ecdh_data data; - run_benchmark("ecdh", bench_ecdh, bench_ecdh_setup, NULL, &data, 10, 20000); + int iters = get_iters(20000); + + /* create a context with no capabilities */ + data.ctx = secp256k1_context_create(SECP256K1_FLAGS_TYPE_CONTEXT); + + run_benchmark("ecdh", bench_ecdh, bench_ecdh_setup, NULL, &data, 10, iters); + + secp256k1_context_destroy(data.ctx); return 0; } diff --git a/src/secp256k1/src/bench_ecmult.c b/src/secp256k1/src/bench_ecmult.c index 6d0ed1f436..facd07ef31 100644 --- a/src/secp256k1/src/bench_ecmult.c +++ b/src/secp256k1/src/bench_ecmult.c @@ -18,7 +18,6 @@ #include "secp256k1.c" #define POINTS 32768 -#define ITERS 10000 typedef struct { /* Setup once in advance */ @@ -55,16 +54,16 @@ static int bench_callback(secp256k1_scalar* sc, secp256k1_ge* ge, size_t idx, vo return 1; } -static void bench_ecmult(void* arg) { +static void bench_ecmult(void* arg, int iters) { bench_data* data = (bench_data*)arg; - size_t count = data->count; int includes_g = data->includes_g; - size_t iters = 1 + ITERS / count; - size_t iter; + int iter; + int count = data->count; + iters = iters / data->count; for (iter = 0; iter < iters; ++iter) { - data->ecmult_multi(&data->ctx->ecmult_ctx, data->scratch, &data->output[iter], data->includes_g ? &data->scalars[data->offset1] : NULL, bench_callback, arg, count - includes_g); + data->ecmult_multi(&data->ctx->error_callback, &data->ctx->ecmult_ctx, data->scratch, &data->output[iter], data->includes_g ? &data->scalars[data->offset1] : NULL, bench_callback, arg, count - includes_g); data->offset1 = (data->offset1 + count) % POINTS; data->offset2 = (data->offset2 + count - 1) % POINTS; } @@ -76,10 +75,10 @@ static void bench_ecmult_setup(void* arg) { data->offset2 = (data->count * 0x7f6f537b + 0x6a1a8f49) % POINTS; } -static void bench_ecmult_teardown(void* arg) { +static void bench_ecmult_teardown(void* arg, int iters) { bench_data* data = (bench_data*)arg; - size_t iters = 1 + ITERS / data->count; - size_t iter; + int iter; + iters = iters / data->count; /* Verify the results in teardown, to avoid doing comparisons while benchmarking. */ for (iter = 0; iter < iters; ++iter) { secp256k1_gej tmp; @@ -104,10 +103,10 @@ static void generate_scalar(uint32_t num, secp256k1_scalar* scalar) { CHECK(!overflow); } -static void run_test(bench_data* data, size_t count, int includes_g) { +static void run_test(bench_data* data, size_t count, int includes_g, int num_iters) { char str[32]; static const secp256k1_scalar zero = SECP256K1_SCALAR_CONST(0, 0, 0, 0, 0, 0, 0, 0); - size_t iters = 1 + ITERS / count; + size_t iters = 1 + num_iters / count; size_t iter; data->count = count; @@ -130,7 +129,7 @@ static void run_test(bench_data* data, size_t count, int includes_g) { /* Run the benchmark. */ sprintf(str, includes_g ? "ecmult_%ig" : "ecmult_%i", (int)count); - run_benchmark(str, bench_ecmult, bench_ecmult_setup, bench_ecmult_teardown, data, 10, count * (1 + ITERS / count)); + run_benchmark(str, bench_ecmult, bench_ecmult_setup, bench_ecmult_teardown, data, 10, count * iters); } int main(int argc, char **argv) { @@ -139,6 +138,8 @@ int main(int argc, char **argv) { secp256k1_gej* pubkeys_gej; size_t scratch_size; + int iters = get_iters(10000); + data.ctx = secp256k1_context_create(SECP256K1_CONTEXT_SIGN | SECP256K1_CONTEXT_VERIFY); scratch_size = secp256k1_strauss_scratch_size(POINTS) + STRAUSS_SCRATCH_OBJECTS*16; data.scratch = secp256k1_scratch_space_create(data.ctx, scratch_size); @@ -154,7 +155,7 @@ int main(int argc, char **argv) { } else if(have_flag(argc, argv, "simple")) { printf("Using simple algorithm:\n"); data.ecmult_multi = secp256k1_ecmult_multi_var; - secp256k1_scratch_space_destroy(data.scratch); + secp256k1_scratch_space_destroy(data.ctx, data.scratch); data.scratch = NULL; } else { fprintf(stderr, "%s: unrecognized argument '%s'.\n", argv[0], argv[1]); @@ -167,8 +168,8 @@ int main(int argc, char **argv) { data.scalars = malloc(sizeof(secp256k1_scalar) * POINTS); data.seckeys = malloc(sizeof(secp256k1_scalar) * POINTS); data.pubkeys = malloc(sizeof(secp256k1_ge) * POINTS); - data.expected_output = malloc(sizeof(secp256k1_gej) * (ITERS + 1)); - data.output = malloc(sizeof(secp256k1_gej) * (ITERS + 1)); + data.expected_output = malloc(sizeof(secp256k1_gej) * (iters + 1)); + data.output = malloc(sizeof(secp256k1_gej) * (iters + 1)); /* Generate a set of scalars, and private/public keypairs. */ pubkeys_gej = malloc(sizeof(secp256k1_gej) * POINTS); @@ -185,18 +186,24 @@ int main(int argc, char **argv) { free(pubkeys_gej); for (i = 1; i <= 8; ++i) { - run_test(&data, i, 1); + run_test(&data, i, 1, iters); } - for (p = 0; p <= 11; ++p) { - for (i = 9; i <= 16; ++i) { - run_test(&data, i << p, 1); + /* This is disabled with low count of iterations because the loop runs 77 times even with iters=1 + * and the higher it goes the longer the computation takes(more points) + * So we don't run this benchmark with low iterations to prevent slow down */ + if (iters > 2) { + for (p = 0; p <= 11; ++p) { + for (i = 9; i <= 16; ++i) { + run_test(&data, i << p, 1, iters); + } } } - secp256k1_context_destroy(data.ctx); + if (data.scratch != NULL) { - secp256k1_scratch_space_destroy(data.scratch); + secp256k1_scratch_space_destroy(data.ctx, data.scratch); } + secp256k1_context_destroy(data.ctx); free(data.scalars); free(data.pubkeys); free(data.seckeys); diff --git a/src/secp256k1/src/bench_internal.c b/src/secp256k1/src/bench_internal.c index 9071724331..20759127d3 100644 --- a/src/secp256k1/src/bench_internal.c +++ b/src/secp256k1/src/bench_internal.c @@ -56,263 +56,272 @@ void bench_setup(void* arg) { memcpy(data->data + 32, init_y, 32); } -void bench_scalar_add(void* arg) { - int i; +void bench_scalar_add(void* arg, int iters) { + int i, j = 0; bench_inv *data = (bench_inv*)arg; - for (i = 0; i < 2000000; i++) { - secp256k1_scalar_add(&data->scalar_x, &data->scalar_x, &data->scalar_y); + for (i = 0; i < iters; i++) { + j += secp256k1_scalar_add(&data->scalar_x, &data->scalar_x, &data->scalar_y); } + CHECK(j <= iters); } -void bench_scalar_negate(void* arg) { +void bench_scalar_negate(void* arg, int iters) { int i; bench_inv *data = (bench_inv*)arg; - for (i = 0; i < 2000000; i++) { + for (i = 0; i < iters; i++) { secp256k1_scalar_negate(&data->scalar_x, &data->scalar_x); } } -void bench_scalar_sqr(void* arg) { +void bench_scalar_sqr(void* arg, int iters) { int i; bench_inv *data = (bench_inv*)arg; - for (i = 0; i < 200000; i++) { + for (i = 0; i < iters; i++) { secp256k1_scalar_sqr(&data->scalar_x, &data->scalar_x); } } -void bench_scalar_mul(void* arg) { +void bench_scalar_mul(void* arg, int iters) { int i; bench_inv *data = (bench_inv*)arg; - for (i = 0; i < 200000; i++) { + for (i = 0; i < iters; i++) { secp256k1_scalar_mul(&data->scalar_x, &data->scalar_x, &data->scalar_y); } } #ifdef USE_ENDOMORPHISM -void bench_scalar_split(void* arg) { - int i; +void bench_scalar_split(void* arg, int iters) { + int i, j = 0; bench_inv *data = (bench_inv*)arg; - for (i = 0; i < 20000; i++) { - secp256k1_scalar l, r; - secp256k1_scalar_split_lambda(&l, &r, &data->scalar_x); - secp256k1_scalar_add(&data->scalar_x, &data->scalar_x, &data->scalar_y); + for (i = 0; i < iters; i++) { + secp256k1_scalar_split_lambda(&data->scalar_x, &data->scalar_y, &data->scalar_x); + j += secp256k1_scalar_add(&data->scalar_x, &data->scalar_x, &data->scalar_y); } + CHECK(j <= iters); } #endif -void bench_scalar_inverse(void* arg) { - int i; +void bench_scalar_inverse(void* arg, int iters) { + int i, j = 0; bench_inv *data = (bench_inv*)arg; - for (i = 0; i < 2000; i++) { + for (i = 0; i < iters; i++) { secp256k1_scalar_inverse(&data->scalar_x, &data->scalar_x); - secp256k1_scalar_add(&data->scalar_x, &data->scalar_x, &data->scalar_y); + j += secp256k1_scalar_add(&data->scalar_x, &data->scalar_x, &data->scalar_y); } + CHECK(j <= iters); } -void bench_scalar_inverse_var(void* arg) { - int i; +void bench_scalar_inverse_var(void* arg, int iters) { + int i, j = 0; bench_inv *data = (bench_inv*)arg; - for (i = 0; i < 2000; i++) { + for (i = 0; i < iters; i++) { secp256k1_scalar_inverse_var(&data->scalar_x, &data->scalar_x); - secp256k1_scalar_add(&data->scalar_x, &data->scalar_x, &data->scalar_y); + j += secp256k1_scalar_add(&data->scalar_x, &data->scalar_x, &data->scalar_y); } + CHECK(j <= iters); } -void bench_field_normalize(void* arg) { +void bench_field_normalize(void* arg, int iters) { int i; bench_inv *data = (bench_inv*)arg; - for (i = 0; i < 2000000; i++) { + for (i = 0; i < iters; i++) { secp256k1_fe_normalize(&data->fe_x); } } -void bench_field_normalize_weak(void* arg) { +void bench_field_normalize_weak(void* arg, int iters) { int i; bench_inv *data = (bench_inv*)arg; - for (i = 0; i < 2000000; i++) { + for (i = 0; i < iters; i++) { secp256k1_fe_normalize_weak(&data->fe_x); } } -void bench_field_mul(void* arg) { +void bench_field_mul(void* arg, int iters) { int i; bench_inv *data = (bench_inv*)arg; - for (i = 0; i < 200000; i++) { + for (i = 0; i < iters; i++) { secp256k1_fe_mul(&data->fe_x, &data->fe_x, &data->fe_y); } } -void bench_field_sqr(void* arg) { +void bench_field_sqr(void* arg, int iters) { int i; bench_inv *data = (bench_inv*)arg; - for (i = 0; i < 200000; i++) { + for (i = 0; i < iters; i++) { secp256k1_fe_sqr(&data->fe_x, &data->fe_x); } } -void bench_field_inverse(void* arg) { +void bench_field_inverse(void* arg, int iters) { int i; bench_inv *data = (bench_inv*)arg; - for (i = 0; i < 20000; i++) { + for (i = 0; i < iters; i++) { secp256k1_fe_inv(&data->fe_x, &data->fe_x); secp256k1_fe_add(&data->fe_x, &data->fe_y); } } -void bench_field_inverse_var(void* arg) { +void bench_field_inverse_var(void* arg, int iters) { int i; bench_inv *data = (bench_inv*)arg; - for (i = 0; i < 20000; i++) { + for (i = 0; i < iters; i++) { secp256k1_fe_inv_var(&data->fe_x, &data->fe_x); secp256k1_fe_add(&data->fe_x, &data->fe_y); } } -void bench_field_sqrt(void* arg) { - int i; +void bench_field_sqrt(void* arg, int iters) { + int i, j = 0; bench_inv *data = (bench_inv*)arg; secp256k1_fe t; - for (i = 0; i < 20000; i++) { + for (i = 0; i < iters; i++) { t = data->fe_x; - secp256k1_fe_sqrt(&data->fe_x, &t); + j += secp256k1_fe_sqrt(&data->fe_x, &t); secp256k1_fe_add(&data->fe_x, &data->fe_y); } + CHECK(j <= iters); } -void bench_group_double_var(void* arg) { +void bench_group_double_var(void* arg, int iters) { int i; bench_inv *data = (bench_inv*)arg; - for (i = 0; i < 200000; i++) { + for (i = 0; i < iters; i++) { secp256k1_gej_double_var(&data->gej_x, &data->gej_x, NULL); } } -void bench_group_add_var(void* arg) { +void bench_group_add_var(void* arg, int iters) { int i; bench_inv *data = (bench_inv*)arg; - for (i = 0; i < 200000; i++) { + for (i = 0; i < iters; i++) { secp256k1_gej_add_var(&data->gej_x, &data->gej_x, &data->gej_y, NULL); } } -void bench_group_add_affine(void* arg) { +void bench_group_add_affine(void* arg, int iters) { int i; bench_inv *data = (bench_inv*)arg; - for (i = 0; i < 200000; i++) { + for (i = 0; i < iters; i++) { secp256k1_gej_add_ge(&data->gej_x, &data->gej_x, &data->ge_y); } } -void bench_group_add_affine_var(void* arg) { +void bench_group_add_affine_var(void* arg, int iters) { int i; bench_inv *data = (bench_inv*)arg; - for (i = 0; i < 200000; i++) { + for (i = 0; i < iters; i++) { secp256k1_gej_add_ge_var(&data->gej_x, &data->gej_x, &data->ge_y, NULL); } } -void bench_group_jacobi_var(void* arg) { - int i; +void bench_group_jacobi_var(void* arg, int iters) { + int i, j = 0; bench_inv *data = (bench_inv*)arg; - for (i = 0; i < 20000; i++) { - secp256k1_gej_has_quad_y_var(&data->gej_x); + for (i = 0; i < iters; i++) { + j += secp256k1_gej_has_quad_y_var(&data->gej_x); } + CHECK(j == iters); } -void bench_ecmult_wnaf(void* arg) { - int i; +void bench_ecmult_wnaf(void* arg, int iters) { + int i, bits = 0, overflow = 0; bench_inv *data = (bench_inv*)arg; - for (i = 0; i < 20000; i++) { - secp256k1_ecmult_wnaf(data->wnaf, 256, &data->scalar_x, WINDOW_A); - secp256k1_scalar_add(&data->scalar_x, &data->scalar_x, &data->scalar_y); + for (i = 0; i < iters; i++) { + bits += secp256k1_ecmult_wnaf(data->wnaf, 256, &data->scalar_x, WINDOW_A); + overflow += secp256k1_scalar_add(&data->scalar_x, &data->scalar_x, &data->scalar_y); } + CHECK(overflow >= 0); + CHECK(bits <= 256*iters); } -void bench_wnaf_const(void* arg) { - int i; +void bench_wnaf_const(void* arg, int iters) { + int i, bits = 0, overflow = 0; bench_inv *data = (bench_inv*)arg; - for (i = 0; i < 20000; i++) { - secp256k1_wnaf_const(data->wnaf, data->scalar_x, WINDOW_A, 256); - secp256k1_scalar_add(&data->scalar_x, &data->scalar_x, &data->scalar_y); + for (i = 0; i < iters; i++) { + bits += secp256k1_wnaf_const(data->wnaf, &data->scalar_x, WINDOW_A, 256); + overflow += secp256k1_scalar_add(&data->scalar_x, &data->scalar_x, &data->scalar_y); } + CHECK(overflow >= 0); + CHECK(bits <= 256*iters); } -void bench_sha256(void* arg) { +void bench_sha256(void* arg, int iters) { int i; bench_inv *data = (bench_inv*)arg; secp256k1_sha256 sha; - for (i = 0; i < 20000; i++) { + for (i = 0; i < iters; i++) { secp256k1_sha256_initialize(&sha); secp256k1_sha256_write(&sha, data->data, 32); secp256k1_sha256_finalize(&sha, data->data); } } -void bench_hmac_sha256(void* arg) { +void bench_hmac_sha256(void* arg, int iters) { int i; bench_inv *data = (bench_inv*)arg; secp256k1_hmac_sha256 hmac; - for (i = 0; i < 20000; i++) { + for (i = 0; i < iters; i++) { secp256k1_hmac_sha256_initialize(&hmac, data->data, 32); secp256k1_hmac_sha256_write(&hmac, data->data, 32); secp256k1_hmac_sha256_finalize(&hmac, data->data); } } -void bench_rfc6979_hmac_sha256(void* arg) { +void bench_rfc6979_hmac_sha256(void* arg, int iters) { int i; bench_inv *data = (bench_inv*)arg; secp256k1_rfc6979_hmac_sha256 rng; - for (i = 0; i < 20000; i++) { + for (i = 0; i < iters; i++) { secp256k1_rfc6979_hmac_sha256_initialize(&rng, data->data, 64); secp256k1_rfc6979_hmac_sha256_generate(&rng, data->data, 32); } } -void bench_context_verify(void* arg) { +void bench_context_verify(void* arg, int iters) { int i; (void)arg; - for (i = 0; i < 20; i++) { + for (i = 0; i < iters; i++) { secp256k1_context_destroy(secp256k1_context_create(SECP256K1_CONTEXT_VERIFY)); } } -void bench_context_sign(void* arg) { +void bench_context_sign(void* arg, int iters) { int i; (void)arg; - for (i = 0; i < 200; i++) { + for (i = 0; i < iters; i++) { secp256k1_context_destroy(secp256k1_context_create(SECP256K1_CONTEXT_SIGN)); } } #ifndef USE_NUM_NONE -void bench_num_jacobi(void* arg) { - int i; +void bench_num_jacobi(void* arg, int iters) { + int i, j = 0; bench_inv *data = (bench_inv*)arg; secp256k1_num nx, norder; @@ -320,50 +329,53 @@ void bench_num_jacobi(void* arg) { secp256k1_scalar_order_get_num(&norder); secp256k1_scalar_get_num(&norder, &data->scalar_y); - for (i = 0; i < 200000; i++) { - secp256k1_num_jacobi(&nx, &norder); + for (i = 0; i < iters; i++) { + j += secp256k1_num_jacobi(&nx, &norder); } + CHECK(j <= iters); } #endif int main(int argc, char **argv) { bench_inv data; - if (have_flag(argc, argv, "scalar") || have_flag(argc, argv, "add")) run_benchmark("scalar_add", bench_scalar_add, bench_setup, NULL, &data, 10, 2000000); - if (have_flag(argc, argv, "scalar") || have_flag(argc, argv, "negate")) run_benchmark("scalar_negate", bench_scalar_negate, bench_setup, NULL, &data, 10, 2000000); - if (have_flag(argc, argv, "scalar") || have_flag(argc, argv, "sqr")) run_benchmark("scalar_sqr", bench_scalar_sqr, bench_setup, NULL, &data, 10, 200000); - if (have_flag(argc, argv, "scalar") || have_flag(argc, argv, "mul")) run_benchmark("scalar_mul", bench_scalar_mul, bench_setup, NULL, &data, 10, 200000); + int iters = get_iters(20000); + + if (have_flag(argc, argv, "scalar") || have_flag(argc, argv, "add")) run_benchmark("scalar_add", bench_scalar_add, bench_setup, NULL, &data, 10, iters*100); + if (have_flag(argc, argv, "scalar") || have_flag(argc, argv, "negate")) run_benchmark("scalar_negate", bench_scalar_negate, bench_setup, NULL, &data, 10, iters*100); + if (have_flag(argc, argv, "scalar") || have_flag(argc, argv, "sqr")) run_benchmark("scalar_sqr", bench_scalar_sqr, bench_setup, NULL, &data, 10, iters*10); + if (have_flag(argc, argv, "scalar") || have_flag(argc, argv, "mul")) run_benchmark("scalar_mul", bench_scalar_mul, bench_setup, NULL, &data, 10, iters*10); #ifdef USE_ENDOMORPHISM - if (have_flag(argc, argv, "scalar") || have_flag(argc, argv, "split")) run_benchmark("scalar_split", bench_scalar_split, bench_setup, NULL, &data, 10, 20000); + if (have_flag(argc, argv, "scalar") || have_flag(argc, argv, "split")) run_benchmark("scalar_split", bench_scalar_split, bench_setup, NULL, &data, 10, iters); #endif if (have_flag(argc, argv, "scalar") || have_flag(argc, argv, "inverse")) run_benchmark("scalar_inverse", bench_scalar_inverse, bench_setup, NULL, &data, 10, 2000); if (have_flag(argc, argv, "scalar") || have_flag(argc, argv, "inverse")) run_benchmark("scalar_inverse_var", bench_scalar_inverse_var, bench_setup, NULL, &data, 10, 2000); - if (have_flag(argc, argv, "field") || have_flag(argc, argv, "normalize")) run_benchmark("field_normalize", bench_field_normalize, bench_setup, NULL, &data, 10, 2000000); - if (have_flag(argc, argv, "field") || have_flag(argc, argv, "normalize")) run_benchmark("field_normalize_weak", bench_field_normalize_weak, bench_setup, NULL, &data, 10, 2000000); - if (have_flag(argc, argv, "field") || have_flag(argc, argv, "sqr")) run_benchmark("field_sqr", bench_field_sqr, bench_setup, NULL, &data, 10, 200000); - if (have_flag(argc, argv, "field") || have_flag(argc, argv, "mul")) run_benchmark("field_mul", bench_field_mul, bench_setup, NULL, &data, 10, 200000); - if (have_flag(argc, argv, "field") || have_flag(argc, argv, "inverse")) run_benchmark("field_inverse", bench_field_inverse, bench_setup, NULL, &data, 10, 20000); - if (have_flag(argc, argv, "field") || have_flag(argc, argv, "inverse")) run_benchmark("field_inverse_var", bench_field_inverse_var, bench_setup, NULL, &data, 10, 20000); - if (have_flag(argc, argv, "field") || have_flag(argc, argv, "sqrt")) run_benchmark("field_sqrt", bench_field_sqrt, bench_setup, NULL, &data, 10, 20000); + if (have_flag(argc, argv, "field") || have_flag(argc, argv, "normalize")) run_benchmark("field_normalize", bench_field_normalize, bench_setup, NULL, &data, 10, iters*100); + if (have_flag(argc, argv, "field") || have_flag(argc, argv, "normalize")) run_benchmark("field_normalize_weak", bench_field_normalize_weak, bench_setup, NULL, &data, 10, iters*100); + if (have_flag(argc, argv, "field") || have_flag(argc, argv, "sqr")) run_benchmark("field_sqr", bench_field_sqr, bench_setup, NULL, &data, 10, iters*10); + if (have_flag(argc, argv, "field") || have_flag(argc, argv, "mul")) run_benchmark("field_mul", bench_field_mul, bench_setup, NULL, &data, 10, iters*10); + if (have_flag(argc, argv, "field") || have_flag(argc, argv, "inverse")) run_benchmark("field_inverse", bench_field_inverse, bench_setup, NULL, &data, 10, iters); + if (have_flag(argc, argv, "field") || have_flag(argc, argv, "inverse")) run_benchmark("field_inverse_var", bench_field_inverse_var, bench_setup, NULL, &data, 10, iters); + if (have_flag(argc, argv, "field") || have_flag(argc, argv, "sqrt")) run_benchmark("field_sqrt", bench_field_sqrt, bench_setup, NULL, &data, 10, iters); - if (have_flag(argc, argv, "group") || have_flag(argc, argv, "double")) run_benchmark("group_double_var", bench_group_double_var, bench_setup, NULL, &data, 10, 200000); - if (have_flag(argc, argv, "group") || have_flag(argc, argv, "add")) run_benchmark("group_add_var", bench_group_add_var, bench_setup, NULL, &data, 10, 200000); - if (have_flag(argc, argv, "group") || have_flag(argc, argv, "add")) run_benchmark("group_add_affine", bench_group_add_affine, bench_setup, NULL, &data, 10, 200000); - if (have_flag(argc, argv, "group") || have_flag(argc, argv, "add")) run_benchmark("group_add_affine_var", bench_group_add_affine_var, bench_setup, NULL, &data, 10, 200000); - if (have_flag(argc, argv, "group") || have_flag(argc, argv, "jacobi")) run_benchmark("group_jacobi_var", bench_group_jacobi_var, bench_setup, NULL, &data, 10, 20000); + if (have_flag(argc, argv, "group") || have_flag(argc, argv, "double")) run_benchmark("group_double_var", bench_group_double_var, bench_setup, NULL, &data, 10, iters*10); + if (have_flag(argc, argv, "group") || have_flag(argc, argv, "add")) run_benchmark("group_add_var", bench_group_add_var, bench_setup, NULL, &data, 10, iters*10); + if (have_flag(argc, argv, "group") || have_flag(argc, argv, "add")) run_benchmark("group_add_affine", bench_group_add_affine, bench_setup, NULL, &data, 10, iters*10); + if (have_flag(argc, argv, "group") || have_flag(argc, argv, "add")) run_benchmark("group_add_affine_var", bench_group_add_affine_var, bench_setup, NULL, &data, 10, iters*10); + if (have_flag(argc, argv, "group") || have_flag(argc, argv, "jacobi")) run_benchmark("group_jacobi_var", bench_group_jacobi_var, bench_setup, NULL, &data, 10, iters); - if (have_flag(argc, argv, "ecmult") || have_flag(argc, argv, "wnaf")) run_benchmark("wnaf_const", bench_wnaf_const, bench_setup, NULL, &data, 10, 20000); - if (have_flag(argc, argv, "ecmult") || have_flag(argc, argv, "wnaf")) run_benchmark("ecmult_wnaf", bench_ecmult_wnaf, bench_setup, NULL, &data, 10, 20000); + if (have_flag(argc, argv, "ecmult") || have_flag(argc, argv, "wnaf")) run_benchmark("wnaf_const", bench_wnaf_const, bench_setup, NULL, &data, 10, iters); + if (have_flag(argc, argv, "ecmult") || have_flag(argc, argv, "wnaf")) run_benchmark("ecmult_wnaf", bench_ecmult_wnaf, bench_setup, NULL, &data, 10, iters); - if (have_flag(argc, argv, "hash") || have_flag(argc, argv, "sha256")) run_benchmark("hash_sha256", bench_sha256, bench_setup, NULL, &data, 10, 20000); - if (have_flag(argc, argv, "hash") || have_flag(argc, argv, "hmac")) run_benchmark("hash_hmac_sha256", bench_hmac_sha256, bench_setup, NULL, &data, 10, 20000); - if (have_flag(argc, argv, "hash") || have_flag(argc, argv, "rng6979")) run_benchmark("hash_rfc6979_hmac_sha256", bench_rfc6979_hmac_sha256, bench_setup, NULL, &data, 10, 20000); + if (have_flag(argc, argv, "hash") || have_flag(argc, argv, "sha256")) run_benchmark("hash_sha256", bench_sha256, bench_setup, NULL, &data, 10, iters); + if (have_flag(argc, argv, "hash") || have_flag(argc, argv, "hmac")) run_benchmark("hash_hmac_sha256", bench_hmac_sha256, bench_setup, NULL, &data, 10, iters); + if (have_flag(argc, argv, "hash") || have_flag(argc, argv, "rng6979")) run_benchmark("hash_rfc6979_hmac_sha256", bench_rfc6979_hmac_sha256, bench_setup, NULL, &data, 10, iters); - if (have_flag(argc, argv, "context") || have_flag(argc, argv, "verify")) run_benchmark("context_verify", bench_context_verify, bench_setup, NULL, &data, 10, 20); - if (have_flag(argc, argv, "context") || have_flag(argc, argv, "sign")) run_benchmark("context_sign", bench_context_sign, bench_setup, NULL, &data, 10, 200); + if (have_flag(argc, argv, "context") || have_flag(argc, argv, "verify")) run_benchmark("context_verify", bench_context_verify, bench_setup, NULL, &data, 10, 1 + iters/1000); + if (have_flag(argc, argv, "context") || have_flag(argc, argv, "sign")) run_benchmark("context_sign", bench_context_sign, bench_setup, NULL, &data, 10, 1 + iters/100); #ifndef USE_NUM_NONE - if (have_flag(argc, argv, "num") || have_flag(argc, argv, "jacobi")) run_benchmark("num_jacobi", bench_num_jacobi, bench_setup, NULL, &data, 10, 200000); + if (have_flag(argc, argv, "num") || have_flag(argc, argv, "jacobi")) run_benchmark("num_jacobi", bench_num_jacobi, bench_setup, NULL, &data, 10, iters*10); #endif return 0; } diff --git a/src/secp256k1/src/bench_recover.c b/src/secp256k1/src/bench_recover.c index b806eed94e..e952ed1215 100644 --- a/src/secp256k1/src/bench_recover.c +++ b/src/secp256k1/src/bench_recover.c @@ -15,13 +15,13 @@ typedef struct { unsigned char sig[64]; } bench_recover_data; -void bench_recover(void* arg) { +void bench_recover(void* arg, int iters) { int i; bench_recover_data *data = (bench_recover_data*)arg; secp256k1_pubkey pubkey; unsigned char pubkeyc[33]; - for (i = 0; i < 20000; i++) { + for (i = 0; i < iters; i++) { int j; size_t pubkeylen = 33; secp256k1_ecdsa_recoverable_signature sig; @@ -51,9 +51,11 @@ void bench_recover_setup(void* arg) { int main(void) { bench_recover_data data; + int iters = get_iters(20000); + data.ctx = secp256k1_context_create(SECP256K1_CONTEXT_VERIFY); - run_benchmark("ecdsa_recover", bench_recover, bench_recover_setup, NULL, &data, 10, 20000); + run_benchmark("ecdsa_recover", bench_recover, bench_recover_setup, NULL, &data, 10, iters); secp256k1_context_destroy(data.ctx); return 0; diff --git a/src/secp256k1/src/bench_sign.c b/src/secp256k1/src/bench_sign.c index 544b43963c..c6b2942cc0 100644 --- a/src/secp256k1/src/bench_sign.c +++ b/src/secp256k1/src/bench_sign.c @@ -26,12 +26,12 @@ static void bench_sign_setup(void* arg) { } } -static void bench_sign_run(void* arg) { +static void bench_sign_run(void* arg, int iters) { int i; bench_sign *data = (bench_sign*)arg; unsigned char sig[74]; - for (i = 0; i < 20000; i++) { + for (i = 0; i < iters; i++) { size_t siglen = 74; int j; secp256k1_ecdsa_signature signature; @@ -47,9 +47,11 @@ static void bench_sign_run(void* arg) { int main(void) { bench_sign data; + int iters = get_iters(20000); + data.ctx = secp256k1_context_create(SECP256K1_CONTEXT_SIGN); - run_benchmark("ecdsa_sign", bench_sign_run, bench_sign_setup, NULL, &data, 10, 20000); + run_benchmark("ecdsa_sign", bench_sign_run, bench_sign_setup, NULL, &data, 10, iters); secp256k1_context_destroy(data.ctx); return 0; diff --git a/src/secp256k1/src/bench_verify.c b/src/secp256k1/src/bench_verify.c index 418defa0aa..272d3e5cc4 100644 --- a/src/secp256k1/src/bench_verify.c +++ b/src/secp256k1/src/bench_verify.c @@ -17,6 +17,7 @@ #include <openssl/obj_mac.h> #endif + typedef struct { secp256k1_context *ctx; unsigned char msg[32]; @@ -30,11 +31,11 @@ typedef struct { #endif } benchmark_verify_t; -static void benchmark_verify(void* arg) { +static void benchmark_verify(void* arg, int iters) { int i; benchmark_verify_t* data = (benchmark_verify_t*)arg; - for (i = 0; i < 20000; i++) { + for (i = 0; i < iters; i++) { secp256k1_pubkey pubkey; secp256k1_ecdsa_signature sig; data->sig[data->siglen - 1] ^= (i & 0xFF); @@ -50,11 +51,11 @@ static void benchmark_verify(void* arg) { } #ifdef ENABLE_OPENSSL_TESTS -static void benchmark_verify_openssl(void* arg) { +static void benchmark_verify_openssl(void* arg, int iters) { int i; benchmark_verify_t* data = (benchmark_verify_t*)arg; - for (i = 0; i < 20000; i++) { + for (i = 0; i < iters; i++) { data->sig[data->siglen - 1] ^= (i & 0xFF); data->sig[data->siglen - 2] ^= ((i >> 8) & 0xFF); data->sig[data->siglen - 3] ^= ((i >> 16) & 0xFF); @@ -85,6 +86,8 @@ int main(void) { secp256k1_ecdsa_signature sig; benchmark_verify_t data; + int iters = get_iters(20000); + data.ctx = secp256k1_context_create(SECP256K1_CONTEXT_SIGN | SECP256K1_CONTEXT_VERIFY); for (i = 0; i < 32; i++) { @@ -100,10 +103,10 @@ int main(void) { data.pubkeylen = 33; CHECK(secp256k1_ec_pubkey_serialize(data.ctx, data.pubkey, &data.pubkeylen, &pubkey, SECP256K1_EC_COMPRESSED) == 1); - run_benchmark("ecdsa_verify", benchmark_verify, NULL, NULL, &data, 10, 20000); + run_benchmark("ecdsa_verify", benchmark_verify, NULL, NULL, &data, 10, iters); #ifdef ENABLE_OPENSSL_TESTS data.ec_group = EC_GROUP_new_by_curve_name(NID_secp256k1); - run_benchmark("ecdsa_verify_openssl", benchmark_verify_openssl, NULL, NULL, &data, 10, 20000); + run_benchmark("ecdsa_verify_openssl", benchmark_verify_openssl, NULL, NULL, &data, 10, iters); EC_GROUP_free(data.ec_group); #endif diff --git a/src/secp256k1/src/ecdsa_impl.h b/src/secp256k1/src/ecdsa_impl.h index c3400042d8..5f54b59faa 100644 --- a/src/secp256k1/src/ecdsa_impl.h +++ b/src/secp256k1/src/ecdsa_impl.h @@ -46,68 +46,73 @@ static const secp256k1_fe secp256k1_ecdsa_const_p_minus_order = SECP256K1_FE_CON 0, 0, 0, 1, 0x45512319UL, 0x50B75FC4UL, 0x402DA172UL, 0x2FC9BAEEUL ); -static int secp256k1_der_read_len(const unsigned char **sigp, const unsigned char *sigend) { - int lenleft, b1; - size_t ret = 0; +static int secp256k1_der_read_len(size_t *len, const unsigned char **sigp, const unsigned char *sigend) { + size_t lenleft; + unsigned char b1; + VERIFY_CHECK(len != NULL); + *len = 0; if (*sigp >= sigend) { - return -1; + return 0; } b1 = *((*sigp)++); if (b1 == 0xFF) { /* X.690-0207 8.1.3.5.c the value 0xFF shall not be used. */ - return -1; + return 0; } if ((b1 & 0x80) == 0) { /* X.690-0207 8.1.3.4 short form length octets */ - return b1; + *len = b1; + return 1; } if (b1 == 0x80) { /* Indefinite length is not allowed in DER. */ - return -1; + return 0; } /* X.690-207 8.1.3.5 long form length octets */ - lenleft = b1 & 0x7F; - if (lenleft > sigend - *sigp) { - return -1; + lenleft = b1 & 0x7F; /* lenleft is at least 1 */ + if (lenleft > (size_t)(sigend - *sigp)) { + return 0; } if (**sigp == 0) { /* Not the shortest possible length encoding. */ - return -1; + return 0; } - if ((size_t)lenleft > sizeof(size_t)) { + if (lenleft > sizeof(size_t)) { /* The resulting length would exceed the range of a size_t, so * certainly longer than the passed array size. */ - return -1; + return 0; } while (lenleft > 0) { - ret = (ret << 8) | **sigp; - if (ret + lenleft > (size_t)(sigend - *sigp)) { - /* Result exceeds the length of the passed array. */ - return -1; - } + *len = (*len << 8) | **sigp; (*sigp)++; lenleft--; } - if (ret < 128) { + if (*len > (size_t)(sigend - *sigp)) { + /* Result exceeds the length of the passed array. */ + return 0; + } + if (*len < 128) { /* Not the shortest possible length encoding. */ - return -1; + return 0; } - return ret; + return 1; } static int secp256k1_der_parse_integer(secp256k1_scalar *r, const unsigned char **sig, const unsigned char *sigend) { int overflow = 0; unsigned char ra[32] = {0}; - int rlen; + size_t rlen; if (*sig == sigend || **sig != 0x02) { /* Not a primitive integer (X.690-0207 8.3.1). */ return 0; } (*sig)++; - rlen = secp256k1_der_read_len(sig, sigend); - if (rlen <= 0 || (*sig) + rlen > sigend) { + if (secp256k1_der_read_len(&rlen, sig, sigend) == 0) { + return 0; + } + if (rlen == 0 || *sig + rlen > sigend) { /* Exceeds bounds or not at least length 1 (X.690-0207 8.3.1). */ return 0; } @@ -123,8 +128,11 @@ static int secp256k1_der_parse_integer(secp256k1_scalar *r, const unsigned char /* Negative. */ overflow = 1; } - while (rlen > 0 && **sig == 0) { - /* Skip leading zero bytes */ + /* There is at most one leading zero byte: + * if there were two leading zero bytes, we would have failed and returned 0 + * because of excessive 0x00 padding already. */ + if (rlen > 0 && **sig == 0) { + /* Skip leading zero byte */ rlen--; (*sig)++; } @@ -144,18 +152,16 @@ static int secp256k1_der_parse_integer(secp256k1_scalar *r, const unsigned char static int secp256k1_ecdsa_sig_parse(secp256k1_scalar *rr, secp256k1_scalar *rs, const unsigned char *sig, size_t size) { const unsigned char *sigend = sig + size; - int rlen; + size_t rlen; if (sig == sigend || *(sig++) != 0x30) { /* The encoding doesn't start with a constructed sequence (X.690-0207 8.9.1). */ return 0; } - rlen = secp256k1_der_read_len(&sig, sigend); - if (rlen < 0 || sig + rlen > sigend) { - /* Tuple exceeds bounds */ + if (secp256k1_der_read_len(&rlen, &sig, sigend) == 0) { return 0; } - if (sig + rlen != sigend) { - /* Garbage after tuple. */ + if (rlen != (size_t)(sigend - sig)) { + /* Tuple exceeds bounds or garage after tuple. */ return 0; } @@ -274,6 +280,7 @@ static int secp256k1_ecdsa_sig_sign(const secp256k1_ecmult_gen_context *ctx, sec secp256k1_ge r; secp256k1_scalar n; int overflow = 0; + int high; secp256k1_ecmult_gen(ctx, &rp, nonce); secp256k1_ge_set_gej(&r, &rp); @@ -281,15 +288,11 @@ static int secp256k1_ecdsa_sig_sign(const secp256k1_ecmult_gen_context *ctx, sec secp256k1_fe_normalize(&r.y); secp256k1_fe_get_b32(b, &r.x); secp256k1_scalar_set_b32(sigr, b, &overflow); - /* These two conditions should be checked before calling */ - VERIFY_CHECK(!secp256k1_scalar_is_zero(sigr)); - VERIFY_CHECK(overflow == 0); - if (recid) { /* The overflow condition is cryptographically unreachable as hitting it requires finding the discrete log * of some P where P.x >= order, and only 1 in about 2^127 points meet this criteria. */ - *recid = (overflow ? 2 : 0) | (secp256k1_fe_is_odd(&r.y) ? 1 : 0); + *recid = (overflow << 1) | secp256k1_fe_is_odd(&r.y); } secp256k1_scalar_mul(&n, sigr, seckey); secp256k1_scalar_add(&n, &n, message); @@ -298,16 +301,15 @@ static int secp256k1_ecdsa_sig_sign(const secp256k1_ecmult_gen_context *ctx, sec secp256k1_scalar_clear(&n); secp256k1_gej_clear(&rp); secp256k1_ge_clear(&r); - if (secp256k1_scalar_is_zero(sigs)) { - return 0; - } - if (secp256k1_scalar_is_high(sigs)) { - secp256k1_scalar_negate(sigs, sigs); - if (recid) { - *recid ^= 1; - } + high = secp256k1_scalar_is_high(sigs); + secp256k1_scalar_cond_negate(sigs, high); + if (recid) { + *recid ^= high; } - return 1; + /* P.x = order is on the curve, so technically sig->r could end up being zero, which would be an invalid signature. + * This is cryptographically unreachable as hitting it requires finding the discrete log of P.x = N. + */ + return !secp256k1_scalar_is_zero(sigr) & !secp256k1_scalar_is_zero(sigs); } #endif /* SECP256K1_ECDSA_IMPL_H */ diff --git a/src/secp256k1/src/eckey_impl.h b/src/secp256k1/src/eckey_impl.h index 7c5b789325..e2e72d9303 100644 --- a/src/secp256k1/src/eckey_impl.h +++ b/src/secp256k1/src/eckey_impl.h @@ -54,10 +54,7 @@ static int secp256k1_eckey_pubkey_serialize(secp256k1_ge *elem, unsigned char *p static int secp256k1_eckey_privkey_tweak_add(secp256k1_scalar *key, const secp256k1_scalar *tweak) { secp256k1_scalar_add(key, key, tweak); - if (secp256k1_scalar_is_zero(key)) { - return 0; - } - return 1; + return !secp256k1_scalar_is_zero(key); } static int secp256k1_eckey_pubkey_tweak_add(const secp256k1_ecmult_context *ctx, secp256k1_ge *key, const secp256k1_scalar *tweak) { @@ -75,12 +72,11 @@ static int secp256k1_eckey_pubkey_tweak_add(const secp256k1_ecmult_context *ctx, } static int secp256k1_eckey_privkey_tweak_mul(secp256k1_scalar *key, const secp256k1_scalar *tweak) { - if (secp256k1_scalar_is_zero(tweak)) { - return 0; - } + int ret; + ret = !secp256k1_scalar_is_zero(tweak); secp256k1_scalar_mul(key, key, tweak); - return 1; + return ret; } static int secp256k1_eckey_pubkey_tweak_mul(const secp256k1_ecmult_context *ctx, secp256k1_ge *key, const secp256k1_scalar *tweak) { diff --git a/src/secp256k1/src/ecmult.h b/src/secp256k1/src/ecmult.h index 3d75a960f4..c9b198239d 100644 --- a/src/secp256k1/src/ecmult.h +++ b/src/secp256k1/src/ecmult.h @@ -20,10 +20,10 @@ typedef struct { #endif } secp256k1_ecmult_context; +static const size_t SECP256K1_ECMULT_CONTEXT_PREALLOCATED_SIZE; static void secp256k1_ecmult_context_init(secp256k1_ecmult_context *ctx); -static void secp256k1_ecmult_context_build(secp256k1_ecmult_context *ctx, const secp256k1_callback *cb); -static void secp256k1_ecmult_context_clone(secp256k1_ecmult_context *dst, - const secp256k1_ecmult_context *src, const secp256k1_callback *cb); +static void secp256k1_ecmult_context_build(secp256k1_ecmult_context *ctx, void **prealloc); +static void secp256k1_ecmult_context_finalize_memcpy(secp256k1_ecmult_context *dst, const secp256k1_ecmult_context *src); static void secp256k1_ecmult_context_clear(secp256k1_ecmult_context *ctx); static int secp256k1_ecmult_context_is_built(const secp256k1_ecmult_context *ctx); @@ -43,6 +43,6 @@ typedef int (secp256k1_ecmult_multi_callback)(secp256k1_scalar *sc, secp256k1_ge * 0 if there is not enough scratch space for a single point or * callback returns 0 */ -static int secp256k1_ecmult_multi_var(const secp256k1_ecmult_context *ctx, secp256k1_scratch *scratch, secp256k1_gej *r, const secp256k1_scalar *inp_g_sc, secp256k1_ecmult_multi_callback cb, void *cbdata, size_t n); +static int secp256k1_ecmult_multi_var(const secp256k1_callback* error_callback, const secp256k1_ecmult_context *ctx, secp256k1_scratch *scratch, secp256k1_gej *r, const secp256k1_scalar *inp_g_sc, secp256k1_ecmult_multi_callback cb, void *cbdata, size_t n); #endif /* SECP256K1_ECMULT_H */ diff --git a/src/secp256k1/src/ecmult_const.h b/src/secp256k1/src/ecmult_const.h index d4804b8b68..03bb33257d 100644 --- a/src/secp256k1/src/ecmult_const.h +++ b/src/secp256k1/src/ecmult_const.h @@ -10,8 +10,11 @@ #include "scalar.h" #include "group.h" -/* Here `bits` should be set to the maximum bitlength of the _absolute value_ of `q`, plus - * one because we internally sometimes add 2 to the number during the WNAF conversion. */ +/** + * Multiply: R = q*A (in constant-time) + * Here `bits` should be set to the maximum bitlength of the _absolute value_ of `q`, plus + * one because we internally sometimes add 2 to the number during the WNAF conversion. + */ static void secp256k1_ecmult_const(secp256k1_gej *r, const secp256k1_ge *a, const secp256k1_scalar *q, int bits); #endif /* SECP256K1_ECMULT_CONST_H */ diff --git a/src/secp256k1/src/ecmult_const_impl.h b/src/secp256k1/src/ecmult_const_impl.h index 8411752eb0..6d6d354aa4 100644 --- a/src/secp256k1/src/ecmult_const_impl.h +++ b/src/secp256k1/src/ecmult_const_impl.h @@ -14,16 +14,22 @@ /* This is like `ECMULT_TABLE_GET_GE` but is constant time */ #define ECMULT_CONST_TABLE_GET_GE(r,pre,n,w) do { \ - int m; \ - int abs_n = (n) * (((n) > 0) * 2 - 1); \ - int idx_n = abs_n / 2; \ + int m = 0; \ + /* Extract the sign-bit for a constant time absolute-value. */ \ + int mask = (n) >> (sizeof(n) * CHAR_BIT - 1); \ + int abs_n = ((n) + mask) ^ mask; \ + int idx_n = abs_n >> 1; \ secp256k1_fe neg_y; \ VERIFY_CHECK(((n) & 1) == 1); \ VERIFY_CHECK((n) >= -((1 << ((w)-1)) - 1)); \ VERIFY_CHECK((n) <= ((1 << ((w)-1)) - 1)); \ VERIFY_SETUP(secp256k1_fe_clear(&(r)->x)); \ VERIFY_SETUP(secp256k1_fe_clear(&(r)->y)); \ - for (m = 0; m < ECMULT_TABLE_SIZE(w); m++) { \ + /* Unconditionally set r->x = (pre)[m].x. r->y = (pre)[m].y. because it's either the correct one \ + * or will get replaced in the later iterations, this is needed to make sure `r` is initialized. */ \ + (r)->x = (pre)[m].x; \ + (r)->y = (pre)[m].y; \ + for (m = 1; m < ECMULT_TABLE_SIZE(w); m++) { \ /* This loop is used to avoid secret data in array indices. See * the comment in ecmult_gen_impl.h for rationale. */ \ secp256k1_fe_cmov(&(r)->x, &(pre)[m].x, m == idx_n); \ @@ -44,11 +50,11 @@ * * Adapted from `The Width-w NAF Method Provides Small Memory and Fast Elliptic Scalar * Multiplications Secure against Side Channel Attacks`, Okeya and Tagaki. M. Joye (Ed.) - * CT-RSA 2003, LNCS 2612, pp. 328-443, 2003. Springer-Verlagy Berlin Heidelberg 2003 + * CT-RSA 2003, LNCS 2612, pp. 328-443, 2003. Springer-Verlag Berlin Heidelberg 2003 * * Numbers reference steps of `Algorithm SPA-resistant Width-w NAF with Odd Scalar` on pp. 335 */ -static int secp256k1_wnaf_const(int *wnaf, secp256k1_scalar s, int w, int size) { +static int secp256k1_wnaf_const(int *wnaf, const secp256k1_scalar *scalar, int w, int size) { int global_sign; int skew = 0; int word = 0; @@ -59,8 +65,12 @@ static int secp256k1_wnaf_const(int *wnaf, secp256k1_scalar s, int w, int size) int flip; int bit; - secp256k1_scalar neg_s; + secp256k1_scalar s; int not_neg_one; + + VERIFY_CHECK(w > 0); + VERIFY_CHECK(size > 0); + /* Note that we cannot handle even numbers by negating them to be odd, as is * done in other implementations, since if our scalars were specified to have * width < 256 for performance reasons, their negations would have width 256 @@ -75,12 +85,13 @@ static int secp256k1_wnaf_const(int *wnaf, secp256k1_scalar s, int w, int size) * {1, 2} we want to add to the scalar when ensuring that it's odd. Further * complicating things, -1 interacts badly with `secp256k1_scalar_cadd_bit` and * we need to special-case it in this logic. */ - flip = secp256k1_scalar_is_high(&s); + flip = secp256k1_scalar_is_high(scalar); /* We add 1 to even numbers, 2 to odd ones, noting that negation flips parity */ - bit = flip ^ !secp256k1_scalar_is_even(&s); + bit = flip ^ !secp256k1_scalar_is_even(scalar); /* We check for negative one, since adding 2 to it will cause an overflow */ - secp256k1_scalar_negate(&neg_s, &s); - not_neg_one = !secp256k1_scalar_is_one(&neg_s); + secp256k1_scalar_negate(&s, scalar); + not_neg_one = !secp256k1_scalar_is_one(&s); + s = *scalar; secp256k1_scalar_cadd_bit(&s, bit, not_neg_one); /* If we had negative one, flip == 1, s.d[0] == 0, bit == 1, so caller expects * that we added two to it and flipped it. In fact for -1 these operations are @@ -93,7 +104,7 @@ static int secp256k1_wnaf_const(int *wnaf, secp256k1_scalar s, int w, int size) /* 4 */ u_last = secp256k1_scalar_shr_int(&s, w); - while (word * w < size) { + do { int sign; int even; @@ -109,7 +120,7 @@ static int secp256k1_wnaf_const(int *wnaf, secp256k1_scalar s, int w, int size) wnaf[word++] = u_last * global_sign; u_last = u; - } + } while (word * w < size); wnaf[word] = u * global_sign; VERIFY_CHECK(secp256k1_scalar_is_zero(&s)); @@ -132,7 +143,6 @@ static void secp256k1_ecmult_const(secp256k1_gej *r, const secp256k1_ge *a, cons int wnaf_1[1 + WNAF_SIZE(WINDOW_A - 1)]; int i; - secp256k1_scalar sc = *scalar; /* build wnaf representation for q. */ int rsize = size; @@ -140,13 +150,13 @@ static void secp256k1_ecmult_const(secp256k1_gej *r, const secp256k1_ge *a, cons if (size > 128) { rsize = 128; /* split q into q_1 and q_lam (where q = q_1 + q_lam*lambda, and q_1 and q_lam are ~128 bit) */ - secp256k1_scalar_split_lambda(&q_1, &q_lam, &sc); - skew_1 = secp256k1_wnaf_const(wnaf_1, q_1, WINDOW_A - 1, 128); - skew_lam = secp256k1_wnaf_const(wnaf_lam, q_lam, WINDOW_A - 1, 128); + secp256k1_scalar_split_lambda(&q_1, &q_lam, scalar); + skew_1 = secp256k1_wnaf_const(wnaf_1, &q_1, WINDOW_A - 1, 128); + skew_lam = secp256k1_wnaf_const(wnaf_lam, &q_lam, WINDOW_A - 1, 128); } else #endif { - skew_1 = secp256k1_wnaf_const(wnaf_1, sc, WINDOW_A - 1, size); + skew_1 = secp256k1_wnaf_const(wnaf_1, scalar, WINDOW_A - 1, size); #ifdef USE_ENDOMORPHISM skew_lam = 0; #endif @@ -168,6 +178,7 @@ static void secp256k1_ecmult_const(secp256k1_gej *r, const secp256k1_ge *a, cons for (i = 0; i < ECMULT_TABLE_SIZE(WINDOW_A); i++) { secp256k1_ge_mul_lambda(&pre_a_lam[i], &pre_a[i]); } + } #endif @@ -191,7 +202,7 @@ static void secp256k1_ecmult_const(secp256k1_gej *r, const secp256k1_ge *a, cons int n; int j; for (j = 0; j < WINDOW_A - 1; ++j) { - secp256k1_gej_double_nonzero(r, r, NULL); + secp256k1_gej_double_nonzero(r, r); } n = wnaf_1[i]; diff --git a/src/secp256k1/src/ecmult_gen.h b/src/secp256k1/src/ecmult_gen.h index 7564b7015f..30815e5aa1 100644 --- a/src/secp256k1/src/ecmult_gen.h +++ b/src/secp256k1/src/ecmult_gen.h @@ -10,28 +10,35 @@ #include "scalar.h" #include "group.h" +#if ECMULT_GEN_PREC_BITS != 2 && ECMULT_GEN_PREC_BITS != 4 && ECMULT_GEN_PREC_BITS != 8 +# error "Set ECMULT_GEN_PREC_BITS to 2, 4 or 8." +#endif +#define ECMULT_GEN_PREC_B ECMULT_GEN_PREC_BITS +#define ECMULT_GEN_PREC_G (1 << ECMULT_GEN_PREC_B) +#define ECMULT_GEN_PREC_N (256 / ECMULT_GEN_PREC_B) + typedef struct { /* For accelerating the computation of a*G: * To harden against timing attacks, use the following mechanism: - * * Break up the multiplicand into groups of 4 bits, called n_0, n_1, n_2, ..., n_63. - * * Compute sum(n_i * 16^i * G + U_i, i=0..63), where: - * * U_i = U * 2^i (for i=0..62) - * * U_i = U * (1-2^63) (for i=63) - * where U is a point with no known corresponding scalar. Note that sum(U_i, i=0..63) = 0. - * For each i, and each of the 16 possible values of n_i, (n_i * 16^i * G + U_i) is - * precomputed (call it prec(i, n_i)). The formula now becomes sum(prec(i, n_i), i=0..63). + * * Break up the multiplicand into groups of PREC_B bits, called n_0, n_1, n_2, ..., n_(PREC_N-1). + * * Compute sum(n_i * (PREC_G)^i * G + U_i, i=0 ... PREC_N-1), where: + * * U_i = U * 2^i, for i=0 ... PREC_N-2 + * * U_i = U * (1-2^(PREC_N-1)), for i=PREC_N-1 + * where U is a point with no known corresponding scalar. Note that sum(U_i, i=0 ... PREC_N-1) = 0. + * For each i, and each of the PREC_G possible values of n_i, (n_i * (PREC_G)^i * G + U_i) is + * precomputed (call it prec(i, n_i)). The formula now becomes sum(prec(i, n_i), i=0 ... PREC_N-1). * None of the resulting prec group elements have a known scalar, and neither do any of * the intermediate sums while computing a*G. */ - secp256k1_ge_storage (*prec)[64][16]; /* prec[j][i] = 16^j * i * G + U_i */ + secp256k1_ge_storage (*prec)[ECMULT_GEN_PREC_N][ECMULT_GEN_PREC_G]; /* prec[j][i] = (PREC_G)^j * i * G + U_i */ secp256k1_scalar blind; secp256k1_gej initial; } secp256k1_ecmult_gen_context; +static const size_t SECP256K1_ECMULT_GEN_CONTEXT_PREALLOCATED_SIZE; static void secp256k1_ecmult_gen_context_init(secp256k1_ecmult_gen_context* ctx); -static void secp256k1_ecmult_gen_context_build(secp256k1_ecmult_gen_context* ctx, const secp256k1_callback* cb); -static void secp256k1_ecmult_gen_context_clone(secp256k1_ecmult_gen_context *dst, - const secp256k1_ecmult_gen_context* src, const secp256k1_callback* cb); +static void secp256k1_ecmult_gen_context_build(secp256k1_ecmult_gen_context* ctx, void **prealloc); +static void secp256k1_ecmult_gen_context_finalize_memcpy(secp256k1_ecmult_gen_context *dst, const secp256k1_ecmult_gen_context* src); static void secp256k1_ecmult_gen_context_clear(secp256k1_ecmult_gen_context* ctx); static int secp256k1_ecmult_gen_context_is_built(const secp256k1_ecmult_gen_context* ctx); diff --git a/src/secp256k1/src/ecmult_gen_impl.h b/src/secp256k1/src/ecmult_gen_impl.h index d64505dc00..30ac16518b 100644 --- a/src/secp256k1/src/ecmult_gen_impl.h +++ b/src/secp256k1/src/ecmult_gen_impl.h @@ -7,6 +7,7 @@ #ifndef SECP256K1_ECMULT_GEN_IMPL_H #define SECP256K1_ECMULT_GEN_IMPL_H +#include "util.h" #include "scalar.h" #include "group.h" #include "ecmult_gen.h" @@ -14,23 +15,32 @@ #ifdef USE_ECMULT_STATIC_PRECOMPUTATION #include "ecmult_static_context.h" #endif + +#ifndef USE_ECMULT_STATIC_PRECOMPUTATION + static const size_t SECP256K1_ECMULT_GEN_CONTEXT_PREALLOCATED_SIZE = ROUND_TO_ALIGN(sizeof(*((secp256k1_ecmult_gen_context*) NULL)->prec)); +#else + static const size_t SECP256K1_ECMULT_GEN_CONTEXT_PREALLOCATED_SIZE = 0; +#endif + static void secp256k1_ecmult_gen_context_init(secp256k1_ecmult_gen_context *ctx) { ctx->prec = NULL; } -static void secp256k1_ecmult_gen_context_build(secp256k1_ecmult_gen_context *ctx, const secp256k1_callback* cb) { +static void secp256k1_ecmult_gen_context_build(secp256k1_ecmult_gen_context *ctx, void **prealloc) { #ifndef USE_ECMULT_STATIC_PRECOMPUTATION - secp256k1_ge prec[1024]; + secp256k1_ge prec[ECMULT_GEN_PREC_N * ECMULT_GEN_PREC_G]; secp256k1_gej gj; secp256k1_gej nums_gej; int i, j; + size_t const prealloc_size = SECP256K1_ECMULT_GEN_CONTEXT_PREALLOCATED_SIZE; + void* const base = *prealloc; #endif if (ctx->prec != NULL) { return; } #ifndef USE_ECMULT_STATIC_PRECOMPUTATION - ctx->prec = (secp256k1_ge_storage (*)[64][16])checked_malloc(cb, sizeof(*ctx->prec)); + ctx->prec = (secp256k1_ge_storage (*)[ECMULT_GEN_PREC_N][ECMULT_GEN_PREC_G])manual_alloc(prealloc, prealloc_size, base, prealloc_size); /* get the generator */ secp256k1_gej_set_ge(&gj, &secp256k1_ge_const_g); @@ -54,39 +64,39 @@ static void secp256k1_ecmult_gen_context_build(secp256k1_ecmult_gen_context *ctx /* compute prec. */ { - secp256k1_gej precj[1024]; /* Jacobian versions of prec. */ + secp256k1_gej precj[ECMULT_GEN_PREC_N * ECMULT_GEN_PREC_G]; /* Jacobian versions of prec. */ secp256k1_gej gbase; secp256k1_gej numsbase; - gbase = gj; /* 16^j * G */ + gbase = gj; /* PREC_G^j * G */ numsbase = nums_gej; /* 2^j * nums. */ - for (j = 0; j < 64; j++) { - /* Set precj[j*16 .. j*16+15] to (numsbase, numsbase + gbase, ..., numsbase + 15*gbase). */ - precj[j*16] = numsbase; - for (i = 1; i < 16; i++) { - secp256k1_gej_add_var(&precj[j*16 + i], &precj[j*16 + i - 1], &gbase, NULL); + for (j = 0; j < ECMULT_GEN_PREC_N; j++) { + /* Set precj[j*PREC_G .. j*PREC_G+(PREC_G-1)] to (numsbase, numsbase + gbase, ..., numsbase + (PREC_G-1)*gbase). */ + precj[j*ECMULT_GEN_PREC_G] = numsbase; + for (i = 1; i < ECMULT_GEN_PREC_G; i++) { + secp256k1_gej_add_var(&precj[j*ECMULT_GEN_PREC_G + i], &precj[j*ECMULT_GEN_PREC_G + i - 1], &gbase, NULL); } - /* Multiply gbase by 16. */ - for (i = 0; i < 4; i++) { + /* Multiply gbase by PREC_G. */ + for (i = 0; i < ECMULT_GEN_PREC_B; i++) { secp256k1_gej_double_var(&gbase, &gbase, NULL); } /* Multiply numbase by 2. */ secp256k1_gej_double_var(&numsbase, &numsbase, NULL); - if (j == 62) { + if (j == ECMULT_GEN_PREC_N - 2) { /* In the last iteration, numsbase is (1 - 2^j) * nums instead. */ secp256k1_gej_neg(&numsbase, &numsbase); secp256k1_gej_add_var(&numsbase, &numsbase, &nums_gej, NULL); } } - secp256k1_ge_set_all_gej_var(prec, precj, 1024); + secp256k1_ge_set_all_gej_var(prec, precj, ECMULT_GEN_PREC_N * ECMULT_GEN_PREC_G); } - for (j = 0; j < 64; j++) { - for (i = 0; i < 16; i++) { - secp256k1_ge_to_storage(&(*ctx->prec)[j][i], &prec[j*16 + i]); + for (j = 0; j < ECMULT_GEN_PREC_N; j++) { + for (i = 0; i < ECMULT_GEN_PREC_G; i++) { + secp256k1_ge_to_storage(&(*ctx->prec)[j][i], &prec[j*ECMULT_GEN_PREC_G + i]); } } #else - (void)cb; - ctx->prec = (secp256k1_ge_storage (*)[64][16])secp256k1_ecmult_static_context; + (void)prealloc; + ctx->prec = (secp256k1_ge_storage (*)[ECMULT_GEN_PREC_N][ECMULT_GEN_PREC_G])secp256k1_ecmult_static_context; #endif secp256k1_ecmult_gen_blind(ctx, NULL); } @@ -95,27 +105,18 @@ static int secp256k1_ecmult_gen_context_is_built(const secp256k1_ecmult_gen_cont return ctx->prec != NULL; } -static void secp256k1_ecmult_gen_context_clone(secp256k1_ecmult_gen_context *dst, - const secp256k1_ecmult_gen_context *src, const secp256k1_callback* cb) { - if (src->prec == NULL) { - dst->prec = NULL; - } else { +static void secp256k1_ecmult_gen_context_finalize_memcpy(secp256k1_ecmult_gen_context *dst, const secp256k1_ecmult_gen_context *src) { #ifndef USE_ECMULT_STATIC_PRECOMPUTATION - dst->prec = (secp256k1_ge_storage (*)[64][16])checked_malloc(cb, sizeof(*dst->prec)); - memcpy(dst->prec, src->prec, sizeof(*dst->prec)); + if (src->prec != NULL) { + /* We cast to void* first to suppress a -Wcast-align warning. */ + dst->prec = (secp256k1_ge_storage (*)[ECMULT_GEN_PREC_N][ECMULT_GEN_PREC_G])(void*)((unsigned char*)dst + ((unsigned char*)src->prec - (unsigned char*)src)); + } #else - (void)cb; - dst->prec = src->prec; + (void)dst, (void)src; #endif - dst->initial = src->initial; - dst->blind = src->blind; - } } static void secp256k1_ecmult_gen_context_clear(secp256k1_ecmult_gen_context *ctx) { -#ifndef USE_ECMULT_STATIC_PRECOMPUTATION - free(ctx->prec); -#endif secp256k1_scalar_clear(&ctx->blind); secp256k1_gej_clear(&ctx->initial); ctx->prec = NULL; @@ -132,9 +133,9 @@ static void secp256k1_ecmult_gen(const secp256k1_ecmult_gen_context *ctx, secp25 /* Blind scalar/point multiplication by computing (n-b)G + bG instead of nG. */ secp256k1_scalar_add(&gnb, gn, &ctx->blind); add.infinity = 0; - for (j = 0; j < 64; j++) { - bits = secp256k1_scalar_get_bits(&gnb, j * 4, 4); - for (i = 0; i < 16; i++) { + for (j = 0; j < ECMULT_GEN_PREC_N; j++) { + bits = secp256k1_scalar_get_bits(&gnb, j * ECMULT_GEN_PREC_B, ECMULT_GEN_PREC_B); + for (i = 0; i < ECMULT_GEN_PREC_G; i++) { /** This uses a conditional move to avoid any secret data in array indexes. * _Any_ use of secret indexes has been demonstrated to result in timing * sidechannels, even when the cache-line access patterns are uniform. @@ -162,7 +163,7 @@ static void secp256k1_ecmult_gen_blind(secp256k1_ecmult_gen_context *ctx, const secp256k1_fe s; unsigned char nonce32[32]; secp256k1_rfc6979_hmac_sha256 rng; - int retry; + int overflow; unsigned char keydata[64] = {0}; if (seed32 == NULL) { /* When seed is NULL, reset the initial point and blinding value. */ @@ -182,21 +183,18 @@ static void secp256k1_ecmult_gen_blind(secp256k1_ecmult_gen_context *ctx, const } secp256k1_rfc6979_hmac_sha256_initialize(&rng, keydata, seed32 ? 64 : 32); memset(keydata, 0, sizeof(keydata)); - /* Retry for out of range results to achieve uniformity. */ - do { - secp256k1_rfc6979_hmac_sha256_generate(&rng, nonce32, 32); - retry = !secp256k1_fe_set_b32(&s, nonce32); - retry |= secp256k1_fe_is_zero(&s); - } while (retry); /* This branch true is cryptographically unreachable. Requires sha256_hmac output > Fp. */ + /* Accept unobservably small non-uniformity. */ + secp256k1_rfc6979_hmac_sha256_generate(&rng, nonce32, 32); + overflow = !secp256k1_fe_set_b32(&s, nonce32); + overflow |= secp256k1_fe_is_zero(&s); + secp256k1_fe_cmov(&s, &secp256k1_fe_one, overflow); /* Randomize the projection to defend against multiplier sidechannels. */ secp256k1_gej_rescale(&ctx->initial, &s); secp256k1_fe_clear(&s); - do { - secp256k1_rfc6979_hmac_sha256_generate(&rng, nonce32, 32); - secp256k1_scalar_set_b32(&b, nonce32, &retry); - /* A blinding value of 0 works, but would undermine the projection hardening. */ - retry |= secp256k1_scalar_is_zero(&b); - } while (retry); /* This branch true is cryptographically unreachable. Requires sha256_hmac output > order. */ + secp256k1_rfc6979_hmac_sha256_generate(&rng, nonce32, 32); + secp256k1_scalar_set_b32(&b, nonce32, NULL); + /* A blinding value of 0 works, but would undermine the projection hardening. */ + secp256k1_scalar_cmov(&b, &secp256k1_scalar_one, secp256k1_scalar_is_zero(&b)); secp256k1_rfc6979_hmac_sha256_finalize(&rng); memset(nonce32, 0, 32); secp256k1_ecmult_gen(ctx, &gb, &b); diff --git a/src/secp256k1/src/ecmult_impl.h b/src/secp256k1/src/ecmult_impl.h index 1986914a4f..f03fa9469d 100644 --- a/src/secp256k1/src/ecmult_impl.h +++ b/src/secp256k1/src/ecmult_impl.h @@ -10,6 +10,7 @@ #include <string.h> #include <stdint.h> +#include "util.h" #include "group.h" #include "scalar.h" #include "ecmult.h" @@ -30,16 +31,32 @@ # endif #else /* optimal for 128-bit and 256-bit exponents. */ -#define WINDOW_A 5 -/** larger numbers may result in slightly better performance, at the cost of - exponentially larger precomputed tables. */ -#ifdef USE_ENDOMORPHISM -/** Two tables for window size 15: 1.375 MiB. */ -#define WINDOW_G 15 -#else -/** One table for window size 16: 1.375 MiB. */ -#define WINDOW_G 16 +# define WINDOW_A 5 +/** Larger values for ECMULT_WINDOW_SIZE result in possibly better + * performance at the cost of an exponentially larger precomputed + * table. The exact table size is + * (1 << (WINDOW_G - 2)) * sizeof(secp256k1_ge_storage) bytes, + * where sizeof(secp256k1_ge_storage) is typically 64 bytes but can + * be larger due to platform-specific padding and alignment. + * If the endomorphism optimization is enabled (USE_ENDOMORMPHSIM) + * two tables of this size are used instead of only one. + */ +# define WINDOW_G ECMULT_WINDOW_SIZE #endif + +/* Noone will ever need more than a window size of 24. The code might + * be correct for larger values of ECMULT_WINDOW_SIZE but this is not + * not tested. + * + * The following limitations are known, and there are probably more: + * If WINDOW_G > 27 and size_t has 32 bits, then the code is incorrect + * because the size of the memory object that we allocate (in bytes) + * will not fit in a size_t. + * If WINDOW_G > 31 and int has 32 bits, then the code is incorrect + * because certain expressions will overflow. + */ +#if ECMULT_WINDOW_SIZE < 2 || ECMULT_WINDOW_SIZE > 24 +# error Set ECMULT_WINDOW_SIZE to an integer in range [2..24]. #endif #ifdef USE_ENDOMORPHISM @@ -121,7 +138,7 @@ static void secp256k1_ecmult_odd_multiples_table(int n, secp256k1_gej *prej, sec * It only operates on tables sized for WINDOW_A wnaf multiples. * - secp256k1_ecmult_odd_multiples_table_storage_var, which converts its * resulting point set to actually affine points, and stores those in pre. - * It operates on tables of any size, but uses heap-allocated temporaries. + * It operates on tables of any size. * * To compute a*P + b*G, we compute a table for P using the first function, * and for G using the second (which requires an inverse, but it only needs to @@ -294,6 +311,13 @@ static void secp256k1_ecmult_odd_multiples_table_storage_var(const int n, secp25 } \ } while(0) +static const size_t SECP256K1_ECMULT_CONTEXT_PREALLOCATED_SIZE = + ROUND_TO_ALIGN(sizeof((*((secp256k1_ecmult_context*) NULL)->pre_g)[0]) * ECMULT_TABLE_SIZE(WINDOW_G)) +#ifdef USE_ENDOMORPHISM + + ROUND_TO_ALIGN(sizeof((*((secp256k1_ecmult_context*) NULL)->pre_g_128)[0]) * ECMULT_TABLE_SIZE(WINDOW_G)) +#endif + ; + static void secp256k1_ecmult_context_init(secp256k1_ecmult_context *ctx) { ctx->pre_g = NULL; #ifdef USE_ENDOMORPHISM @@ -301,8 +325,10 @@ static void secp256k1_ecmult_context_init(secp256k1_ecmult_context *ctx) { #endif } -static void secp256k1_ecmult_context_build(secp256k1_ecmult_context *ctx, const secp256k1_callback *cb) { +static void secp256k1_ecmult_context_build(secp256k1_ecmult_context *ctx, void **prealloc) { secp256k1_gej gj; + void* const base = *prealloc; + size_t const prealloc_size = SECP256K1_ECMULT_CONTEXT_PREALLOCATED_SIZE; if (ctx->pre_g != NULL) { return; @@ -311,7 +337,12 @@ static void secp256k1_ecmult_context_build(secp256k1_ecmult_context *ctx, const /* get the generator */ secp256k1_gej_set_ge(&gj, &secp256k1_ge_const_g); - ctx->pre_g = (secp256k1_ge_storage (*)[])checked_malloc(cb, sizeof((*ctx->pre_g)[0]) * ECMULT_TABLE_SIZE(WINDOW_G)); + { + size_t size = sizeof((*ctx->pre_g)[0]) * ((size_t)ECMULT_TABLE_SIZE(WINDOW_G)); + /* check for overflow */ + VERIFY_CHECK(size / sizeof((*ctx->pre_g)[0]) == ((size_t)ECMULT_TABLE_SIZE(WINDOW_G))); + ctx->pre_g = (secp256k1_ge_storage (*)[])manual_alloc(prealloc, sizeof((*ctx->pre_g)[0]) * ECMULT_TABLE_SIZE(WINDOW_G), base, prealloc_size); + } /* precompute the tables with odd multiples */ secp256k1_ecmult_odd_multiples_table_storage_var(ECMULT_TABLE_SIZE(WINDOW_G), *ctx->pre_g, &gj); @@ -321,7 +352,10 @@ static void secp256k1_ecmult_context_build(secp256k1_ecmult_context *ctx, const secp256k1_gej g_128j; int i; - ctx->pre_g_128 = (secp256k1_ge_storage (*)[])checked_malloc(cb, sizeof((*ctx->pre_g_128)[0]) * ECMULT_TABLE_SIZE(WINDOW_G)); + size_t size = sizeof((*ctx->pre_g_128)[0]) * ((size_t) ECMULT_TABLE_SIZE(WINDOW_G)); + /* check for overflow */ + VERIFY_CHECK(size / sizeof((*ctx->pre_g_128)[0]) == ((size_t)ECMULT_TABLE_SIZE(WINDOW_G))); + ctx->pre_g_128 = (secp256k1_ge_storage (*)[])manual_alloc(prealloc, sizeof((*ctx->pre_g_128)[0]) * ECMULT_TABLE_SIZE(WINDOW_G), base, prealloc_size); /* calculate 2^128*generator */ g_128j = gj; @@ -333,22 +367,14 @@ static void secp256k1_ecmult_context_build(secp256k1_ecmult_context *ctx, const #endif } -static void secp256k1_ecmult_context_clone(secp256k1_ecmult_context *dst, - const secp256k1_ecmult_context *src, const secp256k1_callback *cb) { - if (src->pre_g == NULL) { - dst->pre_g = NULL; - } else { - size_t size = sizeof((*dst->pre_g)[0]) * ECMULT_TABLE_SIZE(WINDOW_G); - dst->pre_g = (secp256k1_ge_storage (*)[])checked_malloc(cb, size); - memcpy(dst->pre_g, src->pre_g, size); +static void secp256k1_ecmult_context_finalize_memcpy(secp256k1_ecmult_context *dst, const secp256k1_ecmult_context *src) { + if (src->pre_g != NULL) { + /* We cast to void* first to suppress a -Wcast-align warning. */ + dst->pre_g = (secp256k1_ge_storage (*)[])(void*)((unsigned char*)dst + ((unsigned char*)(src->pre_g) - (unsigned char*)src)); } #ifdef USE_ENDOMORPHISM - if (src->pre_g_128 == NULL) { - dst->pre_g_128 = NULL; - } else { - size_t size = sizeof((*dst->pre_g_128)[0]) * ECMULT_TABLE_SIZE(WINDOW_G); - dst->pre_g_128 = (secp256k1_ge_storage (*)[])checked_malloc(cb, size); - memcpy(dst->pre_g_128, src->pre_g_128, size); + if (src->pre_g_128 != NULL) { + dst->pre_g_128 = (secp256k1_ge_storage (*)[])(void*)((unsigned char*)dst + ((unsigned char*)(src->pre_g_128) - (unsigned char*)src)); } #endif } @@ -358,10 +384,6 @@ static int secp256k1_ecmult_context_is_built(const secp256k1_ecmult_context *ctx } static void secp256k1_ecmult_context_clear(secp256k1_ecmult_context *ctx) { - free(ctx->pre_g); -#ifdef USE_ENDOMORPHISM - free(ctx->pre_g_128); -#endif secp256k1_ecmult_context_init(ctx); } @@ -373,7 +395,7 @@ static void secp256k1_ecmult_context_clear(secp256k1_ecmult_context *ctx) { * than the number of bits in the (absolute value) of the input. */ static int secp256k1_ecmult_wnaf(int *wnaf, int len, const secp256k1_scalar *a, int w) { - secp256k1_scalar s = *a; + secp256k1_scalar s; int last_set_bit = -1; int bit = 0; int sign = 1; @@ -386,6 +408,7 @@ static int secp256k1_ecmult_wnaf(int *wnaf, int len, const secp256k1_scalar *a, memset(wnaf, 0, len * sizeof(wnaf[0])); + s = *a; if (secp256k1_scalar_get_bits(&s, 255, 1)) { secp256k1_scalar_negate(&s, &s); sign = -1; @@ -418,7 +441,7 @@ static int secp256k1_ecmult_wnaf(int *wnaf, int len, const secp256k1_scalar *a, CHECK(carry == 0); while (bit < 256) { CHECK(secp256k1_scalar_get_bits(&s, bit++, 1) == 0); - } + } #endif return last_set_bit + 1; } @@ -626,52 +649,55 @@ static size_t secp256k1_strauss_scratch_size(size_t n_points) { return n_points*point_size; } -static int secp256k1_ecmult_strauss_batch(const secp256k1_ecmult_context *ctx, secp256k1_scratch *scratch, secp256k1_gej *r, const secp256k1_scalar *inp_g_sc, secp256k1_ecmult_multi_callback cb, void *cbdata, size_t n_points, size_t cb_offset) { +static int secp256k1_ecmult_strauss_batch(const secp256k1_callback* error_callback, const secp256k1_ecmult_context *ctx, secp256k1_scratch *scratch, secp256k1_gej *r, const secp256k1_scalar *inp_g_sc, secp256k1_ecmult_multi_callback cb, void *cbdata, size_t n_points, size_t cb_offset) { secp256k1_gej* points; secp256k1_scalar* scalars; struct secp256k1_strauss_state state; size_t i; + const size_t scratch_checkpoint = secp256k1_scratch_checkpoint(error_callback, scratch); secp256k1_gej_set_infinity(r); if (inp_g_sc == NULL && n_points == 0) { return 1; } - if (!secp256k1_scratch_allocate_frame(scratch, secp256k1_strauss_scratch_size(n_points), STRAUSS_SCRATCH_OBJECTS)) { - return 0; - } - points = (secp256k1_gej*)secp256k1_scratch_alloc(scratch, n_points * sizeof(secp256k1_gej)); - scalars = (secp256k1_scalar*)secp256k1_scratch_alloc(scratch, n_points * sizeof(secp256k1_scalar)); - state.prej = (secp256k1_gej*)secp256k1_scratch_alloc(scratch, n_points * ECMULT_TABLE_SIZE(WINDOW_A) * sizeof(secp256k1_gej)); - state.zr = (secp256k1_fe*)secp256k1_scratch_alloc(scratch, n_points * ECMULT_TABLE_SIZE(WINDOW_A) * sizeof(secp256k1_fe)); + points = (secp256k1_gej*)secp256k1_scratch_alloc(error_callback, scratch, n_points * sizeof(secp256k1_gej)); + scalars = (secp256k1_scalar*)secp256k1_scratch_alloc(error_callback, scratch, n_points * sizeof(secp256k1_scalar)); + state.prej = (secp256k1_gej*)secp256k1_scratch_alloc(error_callback, scratch, n_points * ECMULT_TABLE_SIZE(WINDOW_A) * sizeof(secp256k1_gej)); + state.zr = (secp256k1_fe*)secp256k1_scratch_alloc(error_callback, scratch, n_points * ECMULT_TABLE_SIZE(WINDOW_A) * sizeof(secp256k1_fe)); #ifdef USE_ENDOMORPHISM - state.pre_a = (secp256k1_ge*)secp256k1_scratch_alloc(scratch, n_points * 2 * ECMULT_TABLE_SIZE(WINDOW_A) * sizeof(secp256k1_ge)); + state.pre_a = (secp256k1_ge*)secp256k1_scratch_alloc(error_callback, scratch, n_points * 2 * ECMULT_TABLE_SIZE(WINDOW_A) * sizeof(secp256k1_ge)); state.pre_a_lam = state.pre_a + n_points * ECMULT_TABLE_SIZE(WINDOW_A); #else - state.pre_a = (secp256k1_ge*)secp256k1_scratch_alloc(scratch, n_points * ECMULT_TABLE_SIZE(WINDOW_A) * sizeof(secp256k1_ge)); + state.pre_a = (secp256k1_ge*)secp256k1_scratch_alloc(error_callback, scratch, n_points * ECMULT_TABLE_SIZE(WINDOW_A) * sizeof(secp256k1_ge)); #endif - state.ps = (struct secp256k1_strauss_point_state*)secp256k1_scratch_alloc(scratch, n_points * sizeof(struct secp256k1_strauss_point_state)); + state.ps = (struct secp256k1_strauss_point_state*)secp256k1_scratch_alloc(error_callback, scratch, n_points * sizeof(struct secp256k1_strauss_point_state)); + + if (points == NULL || scalars == NULL || state.prej == NULL || state.zr == NULL || state.pre_a == NULL) { + secp256k1_scratch_apply_checkpoint(error_callback, scratch, scratch_checkpoint); + return 0; + } for (i = 0; i < n_points; i++) { secp256k1_ge point; if (!cb(&scalars[i], &point, i+cb_offset, cbdata)) { - secp256k1_scratch_deallocate_frame(scratch); + secp256k1_scratch_apply_checkpoint(error_callback, scratch, scratch_checkpoint); return 0; } secp256k1_gej_set_ge(&points[i], &point); } secp256k1_ecmult_strauss_wnaf(ctx, &state, r, n_points, points, scalars, inp_g_sc); - secp256k1_scratch_deallocate_frame(scratch); + secp256k1_scratch_apply_checkpoint(error_callback, scratch, scratch_checkpoint); return 1; } /* Wrapper for secp256k1_ecmult_multi_func interface */ -static int secp256k1_ecmult_strauss_batch_single(const secp256k1_ecmult_context *actx, secp256k1_scratch *scratch, secp256k1_gej *r, const secp256k1_scalar *inp_g_sc, secp256k1_ecmult_multi_callback cb, void *cbdata, size_t n) { - return secp256k1_ecmult_strauss_batch(actx, scratch, r, inp_g_sc, cb, cbdata, n, 0); +static int secp256k1_ecmult_strauss_batch_single(const secp256k1_callback* error_callback, const secp256k1_ecmult_context *actx, secp256k1_scratch *scratch, secp256k1_gej *r, const secp256k1_scalar *inp_g_sc, secp256k1_ecmult_multi_callback cb, void *cbdata, size_t n) { + return secp256k1_ecmult_strauss_batch(error_callback, actx, scratch, r, inp_g_sc, cb, cbdata, n, 0); } -static size_t secp256k1_strauss_max_points(secp256k1_scratch *scratch) { - return secp256k1_scratch_max_allocation(scratch, STRAUSS_SCRATCH_OBJECTS) / secp256k1_strauss_scratch_size(1); +static size_t secp256k1_strauss_max_points(const secp256k1_callback* error_callback, secp256k1_scratch *scratch) { + return secp256k1_scratch_max_allocation(error_callback, scratch, STRAUSS_SCRATCH_OBJECTS) / secp256k1_strauss_scratch_size(1); } /** Convert a number to WNAF notation. @@ -963,7 +989,8 @@ static size_t secp256k1_pippenger_scratch_size(size_t n_points, int bucket_windo return (sizeof(secp256k1_gej) << bucket_window) + sizeof(struct secp256k1_pippenger_state) + entries * entry_size; } -static int secp256k1_ecmult_pippenger_batch(const secp256k1_ecmult_context *ctx, secp256k1_scratch *scratch, secp256k1_gej *r, const secp256k1_scalar *inp_g_sc, secp256k1_ecmult_multi_callback cb, void *cbdata, size_t n_points, size_t cb_offset) { +static int secp256k1_ecmult_pippenger_batch(const secp256k1_callback* error_callback, const secp256k1_ecmult_context *ctx, secp256k1_scratch *scratch, secp256k1_gej *r, const secp256k1_scalar *inp_g_sc, secp256k1_ecmult_multi_callback cb, void *cbdata, size_t n_points, size_t cb_offset) { + const size_t scratch_checkpoint = secp256k1_scratch_checkpoint(error_callback, scratch); /* Use 2(n+1) with the endomorphism, n+1 without, when calculating batch * sizes. The reason for +1 is that we add the G scalar to the list of * other scalars. */ @@ -988,15 +1015,21 @@ static int secp256k1_ecmult_pippenger_batch(const secp256k1_ecmult_context *ctx, } bucket_window = secp256k1_pippenger_bucket_window(n_points); - if (!secp256k1_scratch_allocate_frame(scratch, secp256k1_pippenger_scratch_size(n_points, bucket_window), PIPPENGER_SCRATCH_OBJECTS)) { + points = (secp256k1_ge *) secp256k1_scratch_alloc(error_callback, scratch, entries * sizeof(*points)); + scalars = (secp256k1_scalar *) secp256k1_scratch_alloc(error_callback, scratch, entries * sizeof(*scalars)); + state_space = (struct secp256k1_pippenger_state *) secp256k1_scratch_alloc(error_callback, scratch, sizeof(*state_space)); + if (points == NULL || scalars == NULL || state_space == NULL) { + secp256k1_scratch_apply_checkpoint(error_callback, scratch, scratch_checkpoint); + return 0; + } + + state_space->ps = (struct secp256k1_pippenger_point_state *) secp256k1_scratch_alloc(error_callback, scratch, entries * sizeof(*state_space->ps)); + state_space->wnaf_na = (int *) secp256k1_scratch_alloc(error_callback, scratch, entries*(WNAF_SIZE(bucket_window+1)) * sizeof(int)); + buckets = (secp256k1_gej *) secp256k1_scratch_alloc(error_callback, scratch, (1<<bucket_window) * sizeof(*buckets)); + if (state_space->ps == NULL || state_space->wnaf_na == NULL || buckets == NULL) { + secp256k1_scratch_apply_checkpoint(error_callback, scratch, scratch_checkpoint); return 0; } - points = (secp256k1_ge *) secp256k1_scratch_alloc(scratch, entries * sizeof(*points)); - scalars = (secp256k1_scalar *) secp256k1_scratch_alloc(scratch, entries * sizeof(*scalars)); - state_space = (struct secp256k1_pippenger_state *) secp256k1_scratch_alloc(scratch, sizeof(*state_space)); - state_space->ps = (struct secp256k1_pippenger_point_state *) secp256k1_scratch_alloc(scratch, entries * sizeof(*state_space->ps)); - state_space->wnaf_na = (int *) secp256k1_scratch_alloc(scratch, entries*(WNAF_SIZE(bucket_window+1)) * sizeof(int)); - buckets = (secp256k1_gej *) secp256k1_scratch_alloc(scratch, sizeof(*buckets) << bucket_window); if (inp_g_sc != NULL) { scalars[0] = *inp_g_sc; @@ -1010,7 +1043,7 @@ static int secp256k1_ecmult_pippenger_batch(const secp256k1_ecmult_context *ctx, while (point_idx < n_points) { if (!cb(&scalars[idx], &points[idx], point_idx + cb_offset, cbdata)) { - secp256k1_scratch_deallocate_frame(scratch); + secp256k1_scratch_apply_checkpoint(error_callback, scratch, scratch_checkpoint); return 0; } idx++; @@ -1034,13 +1067,13 @@ static int secp256k1_ecmult_pippenger_batch(const secp256k1_ecmult_context *ctx, for(i = 0; i < 1<<bucket_window; i++) { secp256k1_gej_clear(&buckets[i]); } - secp256k1_scratch_deallocate_frame(scratch); + secp256k1_scratch_apply_checkpoint(error_callback, scratch, scratch_checkpoint); return 1; } /* Wrapper for secp256k1_ecmult_multi_func interface */ -static int secp256k1_ecmult_pippenger_batch_single(const secp256k1_ecmult_context *actx, secp256k1_scratch *scratch, secp256k1_gej *r, const secp256k1_scalar *inp_g_sc, secp256k1_ecmult_multi_callback cb, void *cbdata, size_t n) { - return secp256k1_ecmult_pippenger_batch(actx, scratch, r, inp_g_sc, cb, cbdata, n, 0); +static int secp256k1_ecmult_pippenger_batch_single(const secp256k1_callback* error_callback, const secp256k1_ecmult_context *actx, secp256k1_scratch *scratch, secp256k1_gej *r, const secp256k1_scalar *inp_g_sc, secp256k1_ecmult_multi_callback cb, void *cbdata, size_t n) { + return secp256k1_ecmult_pippenger_batch(error_callback, actx, scratch, r, inp_g_sc, cb, cbdata, n, 0); } /** @@ -1048,8 +1081,8 @@ static int secp256k1_ecmult_pippenger_batch_single(const secp256k1_ecmult_contex * a given scratch space. The function ensures that fewer points may also be * used. */ -static size_t secp256k1_pippenger_max_points(secp256k1_scratch *scratch) { - size_t max_alloc = secp256k1_scratch_max_allocation(scratch, PIPPENGER_SCRATCH_OBJECTS); +static size_t secp256k1_pippenger_max_points(const secp256k1_callback* error_callback, secp256k1_scratch *scratch) { + size_t max_alloc = secp256k1_scratch_max_allocation(error_callback, scratch, PIPPENGER_SCRATCH_OBJECTS); int bucket_window; size_t res = 0; @@ -1131,11 +1164,11 @@ static int secp256k1_ecmult_multi_batch_size_helper(size_t *n_batches, size_t *n return 1; } -typedef int (*secp256k1_ecmult_multi_func)(const secp256k1_ecmult_context*, secp256k1_scratch*, secp256k1_gej*, const secp256k1_scalar*, secp256k1_ecmult_multi_callback cb, void*, size_t); -static int secp256k1_ecmult_multi_var(const secp256k1_ecmult_context *ctx, secp256k1_scratch *scratch, secp256k1_gej *r, const secp256k1_scalar *inp_g_sc, secp256k1_ecmult_multi_callback cb, void *cbdata, size_t n) { +typedef int (*secp256k1_ecmult_multi_func)(const secp256k1_callback* error_callback, const secp256k1_ecmult_context*, secp256k1_scratch*, secp256k1_gej*, const secp256k1_scalar*, secp256k1_ecmult_multi_callback cb, void*, size_t); +static int secp256k1_ecmult_multi_var(const secp256k1_callback* error_callback, const secp256k1_ecmult_context *ctx, secp256k1_scratch *scratch, secp256k1_gej *r, const secp256k1_scalar *inp_g_sc, secp256k1_ecmult_multi_callback cb, void *cbdata, size_t n) { size_t i; - int (*f)(const secp256k1_ecmult_context*, secp256k1_scratch*, secp256k1_gej*, const secp256k1_scalar*, secp256k1_ecmult_multi_callback cb, void*, size_t, size_t); + int (*f)(const secp256k1_callback* error_callback, const secp256k1_ecmult_context*, secp256k1_scratch*, secp256k1_gej*, const secp256k1_scalar*, secp256k1_ecmult_multi_callback cb, void*, size_t, size_t); size_t n_batches; size_t n_batch_points; @@ -1152,16 +1185,18 @@ static int secp256k1_ecmult_multi_var(const secp256k1_ecmult_context *ctx, secp2 return secp256k1_ecmult_multi_simple_var(ctx, r, inp_g_sc, cb, cbdata, n); } - /* Compute the batch sizes for pippenger given a scratch space. If it's greater than a threshold - * use pippenger. Otherwise use strauss */ - if (!secp256k1_ecmult_multi_batch_size_helper(&n_batches, &n_batch_points, secp256k1_pippenger_max_points(scratch), n)) { - return 0; + /* Compute the batch sizes for Pippenger's algorithm given a scratch space. If it's greater than + * a threshold use Pippenger's algorithm. Otherwise use Strauss' algorithm. + * As a first step check if there's enough space for Pippenger's algo (which requires less space + * than Strauss' algo) and if not, use the simple algorithm. */ + if (!secp256k1_ecmult_multi_batch_size_helper(&n_batches, &n_batch_points, secp256k1_pippenger_max_points(error_callback, scratch), n)) { + return secp256k1_ecmult_multi_simple_var(ctx, r, inp_g_sc, cb, cbdata, n); } if (n_batch_points >= ECMULT_PIPPENGER_THRESHOLD) { f = secp256k1_ecmult_pippenger_batch; } else { - if (!secp256k1_ecmult_multi_batch_size_helper(&n_batches, &n_batch_points, secp256k1_strauss_max_points(scratch), n)) { - return 0; + if (!secp256k1_ecmult_multi_batch_size_helper(&n_batches, &n_batch_points, secp256k1_strauss_max_points(error_callback, scratch), n)) { + return secp256k1_ecmult_multi_simple_var(ctx, r, inp_g_sc, cb, cbdata, n); } f = secp256k1_ecmult_strauss_batch; } @@ -1169,7 +1204,7 @@ static int secp256k1_ecmult_multi_var(const secp256k1_ecmult_context *ctx, secp2 size_t nbp = n < n_batch_points ? n : n_batch_points; size_t offset = n_batch_points*i; secp256k1_gej tmp; - if (!f(ctx, scratch, &tmp, i == 0 ? inp_g_sc : NULL, cb, cbdata, nbp, offset)) { + if (!f(error_callback, ctx, scratch, &tmp, i == 0 ? inp_g_sc : NULL, cb, cbdata, nbp, offset)) { return 0; } secp256k1_gej_add_var(r, r, &tmp, NULL); diff --git a/src/secp256k1/src/field.h b/src/secp256k1/src/field.h index bb6692ad57..7993a1f11e 100644 --- a/src/secp256k1/src/field.h +++ b/src/secp256k1/src/field.h @@ -32,10 +32,12 @@ #include "util.h" -/** Normalize a field element. */ +/** Normalize a field element. This brings the field element to a canonical representation, reduces + * its magnitude to 1, and reduces it modulo field size `p`. + */ static void secp256k1_fe_normalize(secp256k1_fe *r); -/** Weakly normalize a field element: reduce it magnitude to 1, but don't fully normalize. */ +/** Weakly normalize a field element: reduce its magnitude to 1, but don't fully normalize. */ static void secp256k1_fe_normalize_weak(secp256k1_fe *r); /** Normalize a field element, without constant-time guarantee. */ @@ -123,10 +125,10 @@ static void secp256k1_fe_to_storage(secp256k1_fe_storage *r, const secp256k1_fe /** Convert a field element back from the storage type. */ static void secp256k1_fe_from_storage(secp256k1_fe *r, const secp256k1_fe_storage *a); -/** If flag is true, set *r equal to *a; otherwise leave it. Constant-time. */ +/** If flag is true, set *r equal to *a; otherwise leave it. Constant-time. Both *r and *a must be initialized.*/ static void secp256k1_fe_storage_cmov(secp256k1_fe_storage *r, const secp256k1_fe_storage *a, int flag); -/** If flag is true, set *r equal to *a; otherwise leave it. Constant-time. */ +/** If flag is true, set *r equal to *a; otherwise leave it. Constant-time. Both *r and *a must be initialized.*/ static void secp256k1_fe_cmov(secp256k1_fe *r, const secp256k1_fe *a, int flag); #endif /* SECP256K1_FIELD_H */ diff --git a/src/secp256k1/src/field_10x26_impl.h b/src/secp256k1/src/field_10x26_impl.h index 4ae4fdcec8..651500ee8e 100644 --- a/src/secp256k1/src/field_10x26_impl.h +++ b/src/secp256k1/src/field_10x26_impl.h @@ -320,6 +320,7 @@ static int secp256k1_fe_cmp_var(const secp256k1_fe *a, const secp256k1_fe *b) { } static int secp256k1_fe_set_b32(secp256k1_fe *r, const unsigned char *a) { + int ret; r->n[0] = (uint32_t)a[31] | ((uint32_t)a[30] << 8) | ((uint32_t)a[29] << 16) | ((uint32_t)(a[28] & 0x3) << 24); r->n[1] = (uint32_t)((a[28] >> 2) & 0x3f) | ((uint32_t)a[27] << 6) | ((uint32_t)a[26] << 14) | ((uint32_t)(a[25] & 0xf) << 22); r->n[2] = (uint32_t)((a[25] >> 4) & 0xf) | ((uint32_t)a[24] << 4) | ((uint32_t)a[23] << 12) | ((uint32_t)(a[22] & 0x3f) << 20); @@ -331,15 +332,17 @@ static int secp256k1_fe_set_b32(secp256k1_fe *r, const unsigned char *a) { r->n[8] = (uint32_t)a[5] | ((uint32_t)a[4] << 8) | ((uint32_t)a[3] << 16) | ((uint32_t)(a[2] & 0x3) << 24); r->n[9] = (uint32_t)((a[2] >> 2) & 0x3f) | ((uint32_t)a[1] << 6) | ((uint32_t)a[0] << 14); - if (r->n[9] == 0x3FFFFFUL && (r->n[8] & r->n[7] & r->n[6] & r->n[5] & r->n[4] & r->n[3] & r->n[2]) == 0x3FFFFFFUL && (r->n[1] + 0x40UL + ((r->n[0] + 0x3D1UL) >> 26)) > 0x3FFFFFFUL) { - return 0; - } + ret = !((r->n[9] == 0x3FFFFFUL) & ((r->n[8] & r->n[7] & r->n[6] & r->n[5] & r->n[4] & r->n[3] & r->n[2]) == 0x3FFFFFFUL) & ((r->n[1] + 0x40UL + ((r->n[0] + 0x3D1UL) >> 26)) > 0x3FFFFFFUL)); #ifdef VERIFY r->magnitude = 1; - r->normalized = 1; - secp256k1_fe_verify(r); + if (ret) { + r->normalized = 1; + secp256k1_fe_verify(r); + } else { + r->normalized = 0; + } #endif - return 1; + return ret; } /** Convert a field element to a 32-byte big endian value. Requires the input to be normalized */ @@ -1094,6 +1097,7 @@ static void secp256k1_fe_sqr(secp256k1_fe *r, const secp256k1_fe *a) { static SECP256K1_INLINE void secp256k1_fe_cmov(secp256k1_fe *r, const secp256k1_fe *a, int flag) { uint32_t mask0, mask1; + VG_CHECK_VERIFY(r->n, sizeof(r->n)); mask0 = flag + ~((uint32_t)0); mask1 = ~mask0; r->n[0] = (r->n[0] & mask0) | (a->n[0] & mask1); @@ -1107,15 +1111,16 @@ static SECP256K1_INLINE void secp256k1_fe_cmov(secp256k1_fe *r, const secp256k1_ r->n[8] = (r->n[8] & mask0) | (a->n[8] & mask1); r->n[9] = (r->n[9] & mask0) | (a->n[9] & mask1); #ifdef VERIFY - if (a->magnitude > r->magnitude) { + if (flag) { r->magnitude = a->magnitude; + r->normalized = a->normalized; } - r->normalized &= a->normalized; #endif } static SECP256K1_INLINE void secp256k1_fe_storage_cmov(secp256k1_fe_storage *r, const secp256k1_fe_storage *a, int flag) { uint32_t mask0, mask1; + VG_CHECK_VERIFY(r->n, sizeof(r->n)); mask0 = flag + ~((uint32_t)0); mask1 = ~mask0; r->n[0] = (r->n[0] & mask0) | (a->n[0] & mask1); diff --git a/src/secp256k1/src/field_5x52_impl.h b/src/secp256k1/src/field_5x52_impl.h index f4263320d5..71a38f915b 100644 --- a/src/secp256k1/src/field_5x52_impl.h +++ b/src/secp256k1/src/field_5x52_impl.h @@ -283,6 +283,7 @@ static int secp256k1_fe_cmp_var(const secp256k1_fe *a, const secp256k1_fe *b) { } static int secp256k1_fe_set_b32(secp256k1_fe *r, const unsigned char *a) { + int ret; r->n[0] = (uint64_t)a[31] | ((uint64_t)a[30] << 8) | ((uint64_t)a[29] << 16) @@ -317,15 +318,17 @@ static int secp256k1_fe_set_b32(secp256k1_fe *r, const unsigned char *a) { | ((uint64_t)a[2] << 24) | ((uint64_t)a[1] << 32) | ((uint64_t)a[0] << 40); - if (r->n[4] == 0x0FFFFFFFFFFFFULL && (r->n[3] & r->n[2] & r->n[1]) == 0xFFFFFFFFFFFFFULL && r->n[0] >= 0xFFFFEFFFFFC2FULL) { - return 0; - } + ret = !((r->n[4] == 0x0FFFFFFFFFFFFULL) & ((r->n[3] & r->n[2] & r->n[1]) == 0xFFFFFFFFFFFFFULL) & (r->n[0] >= 0xFFFFEFFFFFC2FULL)); #ifdef VERIFY r->magnitude = 1; - r->normalized = 1; - secp256k1_fe_verify(r); + if (ret) { + r->normalized = 1; + secp256k1_fe_verify(r); + } else { + r->normalized = 0; + } #endif - return 1; + return ret; } /** Convert a field element to a 32-byte big endian value. Requires the input to be normalized */ @@ -446,6 +449,7 @@ static void secp256k1_fe_sqr(secp256k1_fe *r, const secp256k1_fe *a) { static SECP256K1_INLINE void secp256k1_fe_cmov(secp256k1_fe *r, const secp256k1_fe *a, int flag) { uint64_t mask0, mask1; + VG_CHECK_VERIFY(r->n, sizeof(r->n)); mask0 = flag + ~((uint64_t)0); mask1 = ~mask0; r->n[0] = (r->n[0] & mask0) | (a->n[0] & mask1); @@ -454,15 +458,16 @@ static SECP256K1_INLINE void secp256k1_fe_cmov(secp256k1_fe *r, const secp256k1_ r->n[3] = (r->n[3] & mask0) | (a->n[3] & mask1); r->n[4] = (r->n[4] & mask0) | (a->n[4] & mask1); #ifdef VERIFY - if (a->magnitude > r->magnitude) { + if (flag) { r->magnitude = a->magnitude; + r->normalized = a->normalized; } - r->normalized &= a->normalized; #endif } static SECP256K1_INLINE void secp256k1_fe_storage_cmov(secp256k1_fe_storage *r, const secp256k1_fe_storage *a, int flag) { uint64_t mask0, mask1; + VG_CHECK_VERIFY(r->n, sizeof(r->n)); mask0 = flag + ~((uint64_t)0); mask1 = ~mask0; r->n[0] = (r->n[0] & mask0) | (a->n[0] & mask1); diff --git a/src/secp256k1/src/field_impl.h b/src/secp256k1/src/field_impl.h index 6070caccfe..485921a60e 100644 --- a/src/secp256k1/src/field_impl.h +++ b/src/secp256k1/src/field_impl.h @@ -315,4 +315,6 @@ static int secp256k1_fe_is_quad_var(const secp256k1_fe *a) { #endif } +static const secp256k1_fe secp256k1_fe_one = SECP256K1_FE_CONST(0, 0, 0, 0, 0, 0, 0, 1); + #endif /* SECP256K1_FIELD_IMPL_H */ diff --git a/src/secp256k1/src/gen_context.c b/src/secp256k1/src/gen_context.c index 87d296ebf0..539f574bfd 100644 --- a/src/secp256k1/src/gen_context.c +++ b/src/secp256k1/src/gen_context.c @@ -4,10 +4,16 @@ * file COPYING or http://www.opensource.org/licenses/mit-license.php.* **********************************************************************/ +// Autotools creates libsecp256k1-config.h, of which ECMULT_GEN_PREC_BITS is needed. +// ifndef guard so downstream users can define their own if they do not use autotools. +#if !defined(ECMULT_GEN_PREC_BITS) +#include "libsecp256k1-config.h" +#endif #define USE_BASIC_CONFIG 1 - #include "basic-config.h" + #include "include/secp256k1.h" +#include "util.h" #include "field_impl.h" #include "scalar_impl.h" #include "group_impl.h" @@ -26,6 +32,7 @@ static const secp256k1_callback default_error_callback = { int main(int argc, char **argv) { secp256k1_ecmult_gen_context ctx; + void *prealloc, *base; int inner; int outer; FILE* fp; @@ -38,26 +45,31 @@ int main(int argc, char **argv) { fprintf(stderr, "Could not open src/ecmult_static_context.h for writing!\n"); return -1; } - + fprintf(fp, "#ifndef _SECP256K1_ECMULT_STATIC_CONTEXT_\n"); fprintf(fp, "#define _SECP256K1_ECMULT_STATIC_CONTEXT_\n"); fprintf(fp, "#include \"src/group.h\"\n"); fprintf(fp, "#define SC SECP256K1_GE_STORAGE_CONST\n"); - fprintf(fp, "static const secp256k1_ge_storage secp256k1_ecmult_static_context[64][16] = {\n"); + fprintf(fp, "#if ECMULT_GEN_PREC_N != %d || ECMULT_GEN_PREC_G != %d\n", ECMULT_GEN_PREC_N, ECMULT_GEN_PREC_G); + fprintf(fp, " #error configuration mismatch, invalid ECMULT_GEN_PREC_N, ECMULT_GEN_PREC_G. Try deleting ecmult_static_context.h before the build.\n"); + fprintf(fp, "#endif\n"); + fprintf(fp, "static const secp256k1_ge_storage secp256k1_ecmult_static_context[ECMULT_GEN_PREC_N][ECMULT_GEN_PREC_G] = {\n"); + base = checked_malloc(&default_error_callback, SECP256K1_ECMULT_GEN_CONTEXT_PREALLOCATED_SIZE); + prealloc = base; secp256k1_ecmult_gen_context_init(&ctx); - secp256k1_ecmult_gen_context_build(&ctx, &default_error_callback); - for(outer = 0; outer != 64; outer++) { + secp256k1_ecmult_gen_context_build(&ctx, &prealloc); + for(outer = 0; outer != ECMULT_GEN_PREC_N; outer++) { fprintf(fp,"{\n"); - for(inner = 0; inner != 16; inner++) { + for(inner = 0; inner != ECMULT_GEN_PREC_G; inner++) { fprintf(fp," SC(%uu, %uu, %uu, %uu, %uu, %uu, %uu, %uu, %uu, %uu, %uu, %uu, %uu, %uu, %uu, %uu)", SECP256K1_GE_STORAGE_CONST_GET((*ctx.prec)[outer][inner])); - if (inner != 15) { + if (inner != ECMULT_GEN_PREC_G - 1) { fprintf(fp,",\n"); } else { fprintf(fp,"\n"); } } - if (outer != 63) { + if (outer != ECMULT_GEN_PREC_N - 1) { fprintf(fp,"},\n"); } else { fprintf(fp,"}\n"); @@ -65,10 +77,11 @@ int main(int argc, char **argv) { } fprintf(fp,"};\n"); secp256k1_ecmult_gen_context_clear(&ctx); - + free(base); + fprintf(fp, "#undef SC\n"); fprintf(fp, "#endif\n"); fclose(fp); - + return 0; } diff --git a/src/secp256k1/src/group.h b/src/secp256k1/src/group.h index 8e122ab429..863644f0f0 100644 --- a/src/secp256k1/src/group.h +++ b/src/secp256k1/src/group.h @@ -95,14 +95,13 @@ static int secp256k1_gej_is_infinity(const secp256k1_gej *a); /** Check whether a group element's y coordinate is a quadratic residue. */ static int secp256k1_gej_has_quad_y_var(const secp256k1_gej *a); -/** Set r equal to the double of a. If rzr is not-NULL, r->z = a->z * *rzr (where infinity means an implicit z = 0). - * a may not be zero. Constant time. */ -static void secp256k1_gej_double_nonzero(secp256k1_gej *r, const secp256k1_gej *a, secp256k1_fe *rzr); +/** Set r equal to the double of a, a cannot be infinity. Constant time. */ +static void secp256k1_gej_double_nonzero(secp256k1_gej *r, const secp256k1_gej *a); -/** Set r equal to the double of a. If rzr is not-NULL, r->z = a->z * *rzr (where infinity means an implicit z = 0). */ +/** Set r equal to the double of a. If rzr is not-NULL this sets *rzr such that r->z == a->z * *rzr (where infinity means an implicit z = 0). */ static void secp256k1_gej_double_var(secp256k1_gej *r, const secp256k1_gej *a, secp256k1_fe *rzr); -/** Set r equal to the sum of a and b. If rzr is non-NULL, r->z = a->z * *rzr (a cannot be infinity in that case). */ +/** Set r equal to the sum of a and b. If rzr is non-NULL this sets *rzr such that r->z == a->z * *rzr (a cannot be infinity in that case). */ static void secp256k1_gej_add_var(secp256k1_gej *r, const secp256k1_gej *a, const secp256k1_gej *b, secp256k1_fe *rzr); /** Set r equal to the sum of a and b (with b given in affine coordinates, and not infinity). */ @@ -110,7 +109,7 @@ static void secp256k1_gej_add_ge(secp256k1_gej *r, const secp256k1_gej *a, const /** Set r equal to the sum of a and b (with b given in affine coordinates). This is more efficient than secp256k1_gej_add_var. It is identical to secp256k1_gej_add_ge but without constant-time - guarantee, and b is allowed to be infinity. If rzr is non-NULL, r->z = a->z * *rzr (a cannot be infinity in that case). */ + guarantee, and b is allowed to be infinity. If rzr is non-NULL this sets *rzr such that r->z == a->z * *rzr (a cannot be infinity in that case). */ static void secp256k1_gej_add_ge_var(secp256k1_gej *r, const secp256k1_gej *a, const secp256k1_ge *b, secp256k1_fe *rzr); /** Set r equal to the sum of a and b (with the inverse of b's Z coordinate passed as bzinv). */ @@ -133,7 +132,7 @@ static void secp256k1_ge_to_storage(secp256k1_ge_storage *r, const secp256k1_ge /** Convert a group element back from the storage type. */ static void secp256k1_ge_from_storage(secp256k1_ge *r, const secp256k1_ge_storage *a); -/** If flag is true, set *r equal to *a; otherwise leave it. Constant-time. */ +/** If flag is true, set *r equal to *a; otherwise leave it. Constant-time. Both *r and *a must be initialized.*/ static void secp256k1_ge_storage_cmov(secp256k1_ge_storage *r, const secp256k1_ge_storage *a, int flag); /** Rescale a jacobian point by b which must be non-zero. Constant-time. */ diff --git a/src/secp256k1/src/group_impl.h b/src/secp256k1/src/group_impl.h index 9b93c39e92..43b039becf 100644 --- a/src/secp256k1/src/group_impl.h +++ b/src/secp256k1/src/group_impl.h @@ -303,7 +303,7 @@ static int secp256k1_ge_is_valid_var(const secp256k1_ge *a) { return secp256k1_fe_equal_var(&y2, &x3); } -static void secp256k1_gej_double_var(secp256k1_gej *r, const secp256k1_gej *a, secp256k1_fe *rzr) { +static SECP256K1_INLINE void secp256k1_gej_double_nonzero(secp256k1_gej *r, const secp256k1_gej *a) { /* Operations: 3 mul, 4 sqr, 0 normalize, 12 mul_int/add/negate. * * Note that there is an implementation described at @@ -312,29 +312,9 @@ static void secp256k1_gej_double_var(secp256k1_gej *r, const secp256k1_gej *a, s * mainly because it requires more normalizations. */ secp256k1_fe t1,t2,t3,t4; - /** For secp256k1, 2Q is infinity if and only if Q is infinity. This is because if 2Q = infinity, - * Q must equal -Q, or that Q.y == -(Q.y), or Q.y is 0. For a point on y^2 = x^3 + 7 to have - * y=0, x^3 must be -7 mod p. However, -7 has no cube root mod p. - * - * Having said this, if this function receives a point on a sextic twist, e.g. by - * a fault attack, it is possible for y to be 0. This happens for y^2 = x^3 + 6, - * since -6 does have a cube root mod p. For this point, this function will not set - * the infinity flag even though the point doubles to infinity, and the result - * point will be gibberish (z = 0 but infinity = 0). - */ - r->infinity = a->infinity; - if (r->infinity) { - if (rzr != NULL) { - secp256k1_fe_set_int(rzr, 1); - } - return; - } - if (rzr != NULL) { - *rzr = a->y; - secp256k1_fe_normalize_weak(rzr); - secp256k1_fe_mul_int(rzr, 2); - } + VERIFY_CHECK(!secp256k1_gej_is_infinity(a)); + r->infinity = 0; secp256k1_fe_mul(&r->z, &a->z, &a->y); secp256k1_fe_mul_int(&r->z, 2); /* Z' = 2*Y*Z (2) */ @@ -358,9 +338,32 @@ static void secp256k1_gej_double_var(secp256k1_gej *r, const secp256k1_gej *a, s secp256k1_fe_add(&r->y, &t2); /* Y' = 36*X^3*Y^2 - 27*X^6 - 8*Y^4 (4) */ } -static SECP256K1_INLINE void secp256k1_gej_double_nonzero(secp256k1_gej *r, const secp256k1_gej *a, secp256k1_fe *rzr) { - VERIFY_CHECK(!secp256k1_gej_is_infinity(a)); - secp256k1_gej_double_var(r, a, rzr); +static void secp256k1_gej_double_var(secp256k1_gej *r, const secp256k1_gej *a, secp256k1_fe *rzr) { + /** For secp256k1, 2Q is infinity if and only if Q is infinity. This is because if 2Q = infinity, + * Q must equal -Q, or that Q.y == -(Q.y), or Q.y is 0. For a point on y^2 = x^3 + 7 to have + * y=0, x^3 must be -7 mod p. However, -7 has no cube root mod p. + * + * Having said this, if this function receives a point on a sextic twist, e.g. by + * a fault attack, it is possible for y to be 0. This happens for y^2 = x^3 + 6, + * since -6 does have a cube root mod p. For this point, this function will not set + * the infinity flag even though the point doubles to infinity, and the result + * point will be gibberish (z = 0 but infinity = 0). + */ + if (a->infinity) { + r->infinity = 1; + if (rzr != NULL) { + secp256k1_fe_set_int(rzr, 1); + } + return; + } + + if (rzr != NULL) { + *rzr = a->y; + secp256k1_fe_normalize_weak(rzr); + secp256k1_fe_mul_int(rzr, 2); + } + + secp256k1_gej_double_nonzero(r, a); } static void secp256k1_gej_add_var(secp256k1_gej *r, const secp256k1_gej *a, const secp256k1_gej *b, secp256k1_fe *rzr) { diff --git a/src/secp256k1/src/hash_impl.h b/src/secp256k1/src/hash_impl.h index 009f26beba..782f97216c 100644 --- a/src/secp256k1/src/hash_impl.h +++ b/src/secp256k1/src/hash_impl.h @@ -131,7 +131,8 @@ static void secp256k1_sha256_transform(uint32_t* s, const uint32_t* chunk) { static void secp256k1_sha256_write(secp256k1_sha256 *hash, const unsigned char *data, size_t len) { size_t bufsize = hash->bytes & 0x3F; hash->bytes += len; - while (bufsize + len >= 64) { + VERIFY_CHECK(hash->bytes >= len); + while (len >= 64 - bufsize) { /* Fill the buffer, and process it. */ size_t chunk_len = 64 - bufsize; memcpy(((unsigned char*)hash->buf) + bufsize, data, chunk_len); diff --git a/src/secp256k1/src/java/org/bitcoin/NativeSecp256k1.java b/src/secp256k1/src/java/org/bitcoin/NativeSecp256k1.java deleted file mode 100644 index 1c67802fba..0000000000 --- a/src/secp256k1/src/java/org/bitcoin/NativeSecp256k1.java +++ /dev/null @@ -1,446 +0,0 @@ -/* - * Copyright 2013 Google Inc. - * Copyright 2014-2016 the libsecp256k1 contributors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.bitcoin; - -import java.nio.ByteBuffer; -import java.nio.ByteOrder; - -import java.math.BigInteger; -import com.google.common.base.Preconditions; -import java.util.concurrent.locks.Lock; -import java.util.concurrent.locks.ReentrantReadWriteLock; -import static org.bitcoin.NativeSecp256k1Util.*; - -/** - * <p>This class holds native methods to handle ECDSA verification.</p> - * - * <p>You can find an example library that can be used for this at https://github.com/bitcoin/secp256k1</p> - * - * <p>To build secp256k1 for use with bitcoinj, run - * `./configure --enable-jni --enable-experimental --enable-module-ecdh` - * and `make` then copy `.libs/libsecp256k1.so` to your system library path - * or point the JVM to the folder containing it with -Djava.library.path - * </p> - */ -public class NativeSecp256k1 { - - private static final ReentrantReadWriteLock rwl = new ReentrantReadWriteLock(); - private static final Lock r = rwl.readLock(); - private static final Lock w = rwl.writeLock(); - private static ThreadLocal<ByteBuffer> nativeECDSABuffer = new ThreadLocal<ByteBuffer>(); - /** - * Verifies the given secp256k1 signature in native code. - * Calling when enabled == false is undefined (probably library not loaded) - * - * @param data The data which was signed, must be exactly 32 bytes - * @param signature The signature - * @param pub The public key which did the signing - */ - public static boolean verify(byte[] data, byte[] signature, byte[] pub) throws AssertFailException{ - Preconditions.checkArgument(data.length == 32 && signature.length <= 520 && pub.length <= 520); - - ByteBuffer byteBuff = nativeECDSABuffer.get(); - if (byteBuff == null || byteBuff.capacity() < 520) { - byteBuff = ByteBuffer.allocateDirect(520); - byteBuff.order(ByteOrder.nativeOrder()); - nativeECDSABuffer.set(byteBuff); - } - byteBuff.rewind(); - byteBuff.put(data); - byteBuff.put(signature); - byteBuff.put(pub); - - byte[][] retByteArray; - - r.lock(); - try { - return secp256k1_ecdsa_verify(byteBuff, Secp256k1Context.getContext(), signature.length, pub.length) == 1; - } finally { - r.unlock(); - } - } - - /** - * libsecp256k1 Create an ECDSA signature. - * - * @param data Message hash, 32 bytes - * @param key Secret key, 32 bytes - * - * Return values - * @param sig byte array of signature - */ - public static byte[] sign(byte[] data, byte[] sec) throws AssertFailException{ - Preconditions.checkArgument(data.length == 32 && sec.length <= 32); - - ByteBuffer byteBuff = nativeECDSABuffer.get(); - if (byteBuff == null || byteBuff.capacity() < 32 + 32) { - byteBuff = ByteBuffer.allocateDirect(32 + 32); - byteBuff.order(ByteOrder.nativeOrder()); - nativeECDSABuffer.set(byteBuff); - } - byteBuff.rewind(); - byteBuff.put(data); - byteBuff.put(sec); - - byte[][] retByteArray; - - r.lock(); - try { - retByteArray = secp256k1_ecdsa_sign(byteBuff, Secp256k1Context.getContext()); - } finally { - r.unlock(); - } - - byte[] sigArr = retByteArray[0]; - int sigLen = new BigInteger(new byte[] { retByteArray[1][0] }).intValue(); - int retVal = new BigInteger(new byte[] { retByteArray[1][1] }).intValue(); - - assertEquals(sigArr.length, sigLen, "Got bad signature length."); - - return retVal == 0 ? new byte[0] : sigArr; - } - - /** - * libsecp256k1 Seckey Verify - returns 1 if valid, 0 if invalid - * - * @param seckey ECDSA Secret key, 32 bytes - */ - public static boolean secKeyVerify(byte[] seckey) { - Preconditions.checkArgument(seckey.length == 32); - - ByteBuffer byteBuff = nativeECDSABuffer.get(); - if (byteBuff == null || byteBuff.capacity() < seckey.length) { - byteBuff = ByteBuffer.allocateDirect(seckey.length); - byteBuff.order(ByteOrder.nativeOrder()); - nativeECDSABuffer.set(byteBuff); - } - byteBuff.rewind(); - byteBuff.put(seckey); - - r.lock(); - try { - return secp256k1_ec_seckey_verify(byteBuff,Secp256k1Context.getContext()) == 1; - } finally { - r.unlock(); - } - } - - - /** - * libsecp256k1 Compute Pubkey - computes public key from secret key - * - * @param seckey ECDSA Secret key, 32 bytes - * - * Return values - * @param pubkey ECDSA Public key, 33 or 65 bytes - */ - //TODO add a 'compressed' arg - public static byte[] computePubkey(byte[] seckey) throws AssertFailException{ - Preconditions.checkArgument(seckey.length == 32); - - ByteBuffer byteBuff = nativeECDSABuffer.get(); - if (byteBuff == null || byteBuff.capacity() < seckey.length) { - byteBuff = ByteBuffer.allocateDirect(seckey.length); - byteBuff.order(ByteOrder.nativeOrder()); - nativeECDSABuffer.set(byteBuff); - } - byteBuff.rewind(); - byteBuff.put(seckey); - - byte[][] retByteArray; - - r.lock(); - try { - retByteArray = secp256k1_ec_pubkey_create(byteBuff, Secp256k1Context.getContext()); - } finally { - r.unlock(); - } - - byte[] pubArr = retByteArray[0]; - int pubLen = new BigInteger(new byte[] { retByteArray[1][0] }).intValue(); - int retVal = new BigInteger(new byte[] { retByteArray[1][1] }).intValue(); - - assertEquals(pubArr.length, pubLen, "Got bad pubkey length."); - - return retVal == 0 ? new byte[0]: pubArr; - } - - /** - * libsecp256k1 Cleanup - This destroys the secp256k1 context object - * This should be called at the end of the program for proper cleanup of the context. - */ - public static synchronized void cleanup() { - w.lock(); - try { - secp256k1_destroy_context(Secp256k1Context.getContext()); - } finally { - w.unlock(); - } - } - - public static long cloneContext() { - r.lock(); - try { - return secp256k1_ctx_clone(Secp256k1Context.getContext()); - } finally { r.unlock(); } - } - - /** - * libsecp256k1 PrivKey Tweak-Mul - Tweak privkey by multiplying to it - * - * @param tweak some bytes to tweak with - * @param seckey 32-byte seckey - */ - public static byte[] privKeyTweakMul(byte[] privkey, byte[] tweak) throws AssertFailException{ - Preconditions.checkArgument(privkey.length == 32); - - ByteBuffer byteBuff = nativeECDSABuffer.get(); - if (byteBuff == null || byteBuff.capacity() < privkey.length + tweak.length) { - byteBuff = ByteBuffer.allocateDirect(privkey.length + tweak.length); - byteBuff.order(ByteOrder.nativeOrder()); - nativeECDSABuffer.set(byteBuff); - } - byteBuff.rewind(); - byteBuff.put(privkey); - byteBuff.put(tweak); - - byte[][] retByteArray; - r.lock(); - try { - retByteArray = secp256k1_privkey_tweak_mul(byteBuff,Secp256k1Context.getContext()); - } finally { - r.unlock(); - } - - byte[] privArr = retByteArray[0]; - - int privLen = (byte) new BigInteger(new byte[] { retByteArray[1][0] }).intValue() & 0xFF; - int retVal = new BigInteger(new byte[] { retByteArray[1][1] }).intValue(); - - assertEquals(privArr.length, privLen, "Got bad pubkey length."); - - assertEquals(retVal, 1, "Failed return value check."); - - return privArr; - } - - /** - * libsecp256k1 PrivKey Tweak-Add - Tweak privkey by adding to it - * - * @param tweak some bytes to tweak with - * @param seckey 32-byte seckey - */ - public static byte[] privKeyTweakAdd(byte[] privkey, byte[] tweak) throws AssertFailException{ - Preconditions.checkArgument(privkey.length == 32); - - ByteBuffer byteBuff = nativeECDSABuffer.get(); - if (byteBuff == null || byteBuff.capacity() < privkey.length + tweak.length) { - byteBuff = ByteBuffer.allocateDirect(privkey.length + tweak.length); - byteBuff.order(ByteOrder.nativeOrder()); - nativeECDSABuffer.set(byteBuff); - } - byteBuff.rewind(); - byteBuff.put(privkey); - byteBuff.put(tweak); - - byte[][] retByteArray; - r.lock(); - try { - retByteArray = secp256k1_privkey_tweak_add(byteBuff,Secp256k1Context.getContext()); - } finally { - r.unlock(); - } - - byte[] privArr = retByteArray[0]; - - int privLen = (byte) new BigInteger(new byte[] { retByteArray[1][0] }).intValue() & 0xFF; - int retVal = new BigInteger(new byte[] { retByteArray[1][1] }).intValue(); - - assertEquals(privArr.length, privLen, "Got bad pubkey length."); - - assertEquals(retVal, 1, "Failed return value check."); - - return privArr; - } - - /** - * libsecp256k1 PubKey Tweak-Add - Tweak pubkey by adding to it - * - * @param tweak some bytes to tweak with - * @param pubkey 32-byte seckey - */ - public static byte[] pubKeyTweakAdd(byte[] pubkey, byte[] tweak) throws AssertFailException{ - Preconditions.checkArgument(pubkey.length == 33 || pubkey.length == 65); - - ByteBuffer byteBuff = nativeECDSABuffer.get(); - if (byteBuff == null || byteBuff.capacity() < pubkey.length + tweak.length) { - byteBuff = ByteBuffer.allocateDirect(pubkey.length + tweak.length); - byteBuff.order(ByteOrder.nativeOrder()); - nativeECDSABuffer.set(byteBuff); - } - byteBuff.rewind(); - byteBuff.put(pubkey); - byteBuff.put(tweak); - - byte[][] retByteArray; - r.lock(); - try { - retByteArray = secp256k1_pubkey_tweak_add(byteBuff,Secp256k1Context.getContext(), pubkey.length); - } finally { - r.unlock(); - } - - byte[] pubArr = retByteArray[0]; - - int pubLen = (byte) new BigInteger(new byte[] { retByteArray[1][0] }).intValue() & 0xFF; - int retVal = new BigInteger(new byte[] { retByteArray[1][1] }).intValue(); - - assertEquals(pubArr.length, pubLen, "Got bad pubkey length."); - - assertEquals(retVal, 1, "Failed return value check."); - - return pubArr; - } - - /** - * libsecp256k1 PubKey Tweak-Mul - Tweak pubkey by multiplying to it - * - * @param tweak some bytes to tweak with - * @param pubkey 32-byte seckey - */ - public static byte[] pubKeyTweakMul(byte[] pubkey, byte[] tweak) throws AssertFailException{ - Preconditions.checkArgument(pubkey.length == 33 || pubkey.length == 65); - - ByteBuffer byteBuff = nativeECDSABuffer.get(); - if (byteBuff == null || byteBuff.capacity() < pubkey.length + tweak.length) { - byteBuff = ByteBuffer.allocateDirect(pubkey.length + tweak.length); - byteBuff.order(ByteOrder.nativeOrder()); - nativeECDSABuffer.set(byteBuff); - } - byteBuff.rewind(); - byteBuff.put(pubkey); - byteBuff.put(tweak); - - byte[][] retByteArray; - r.lock(); - try { - retByteArray = secp256k1_pubkey_tweak_mul(byteBuff,Secp256k1Context.getContext(), pubkey.length); - } finally { - r.unlock(); - } - - byte[] pubArr = retByteArray[0]; - - int pubLen = (byte) new BigInteger(new byte[] { retByteArray[1][0] }).intValue() & 0xFF; - int retVal = new BigInteger(new byte[] { retByteArray[1][1] }).intValue(); - - assertEquals(pubArr.length, pubLen, "Got bad pubkey length."); - - assertEquals(retVal, 1, "Failed return value check."); - - return pubArr; - } - - /** - * libsecp256k1 create ECDH secret - constant time ECDH calculation - * - * @param seckey byte array of secret key used in exponentiaion - * @param pubkey byte array of public key used in exponentiaion - */ - public static byte[] createECDHSecret(byte[] seckey, byte[] pubkey) throws AssertFailException{ - Preconditions.checkArgument(seckey.length <= 32 && pubkey.length <= 65); - - ByteBuffer byteBuff = nativeECDSABuffer.get(); - if (byteBuff == null || byteBuff.capacity() < 32 + pubkey.length) { - byteBuff = ByteBuffer.allocateDirect(32 + pubkey.length); - byteBuff.order(ByteOrder.nativeOrder()); - nativeECDSABuffer.set(byteBuff); - } - byteBuff.rewind(); - byteBuff.put(seckey); - byteBuff.put(pubkey); - - byte[][] retByteArray; - r.lock(); - try { - retByteArray = secp256k1_ecdh(byteBuff, Secp256k1Context.getContext(), pubkey.length); - } finally { - r.unlock(); - } - - byte[] resArr = retByteArray[0]; - int retVal = new BigInteger(new byte[] { retByteArray[1][0] }).intValue(); - - assertEquals(resArr.length, 32, "Got bad result length."); - assertEquals(retVal, 1, "Failed return value check."); - - return resArr; - } - - /** - * libsecp256k1 randomize - updates the context randomization - * - * @param seed 32-byte random seed - */ - public static synchronized boolean randomize(byte[] seed) throws AssertFailException{ - Preconditions.checkArgument(seed.length == 32 || seed == null); - - ByteBuffer byteBuff = nativeECDSABuffer.get(); - if (byteBuff == null || byteBuff.capacity() < seed.length) { - byteBuff = ByteBuffer.allocateDirect(seed.length); - byteBuff.order(ByteOrder.nativeOrder()); - nativeECDSABuffer.set(byteBuff); - } - byteBuff.rewind(); - byteBuff.put(seed); - - w.lock(); - try { - return secp256k1_context_randomize(byteBuff, Secp256k1Context.getContext()) == 1; - } finally { - w.unlock(); - } - } - - private static native long secp256k1_ctx_clone(long context); - - private static native int secp256k1_context_randomize(ByteBuffer byteBuff, long context); - - private static native byte[][] secp256k1_privkey_tweak_add(ByteBuffer byteBuff, long context); - - private static native byte[][] secp256k1_privkey_tweak_mul(ByteBuffer byteBuff, long context); - - private static native byte[][] secp256k1_pubkey_tweak_add(ByteBuffer byteBuff, long context, int pubLen); - - private static native byte[][] secp256k1_pubkey_tweak_mul(ByteBuffer byteBuff, long context, int pubLen); - - private static native void secp256k1_destroy_context(long context); - - private static native int secp256k1_ecdsa_verify(ByteBuffer byteBuff, long context, int sigLen, int pubLen); - - private static native byte[][] secp256k1_ecdsa_sign(ByteBuffer byteBuff, long context); - - private static native int secp256k1_ec_seckey_verify(ByteBuffer byteBuff, long context); - - private static native byte[][] secp256k1_ec_pubkey_create(ByteBuffer byteBuff, long context); - - private static native byte[][] secp256k1_ec_pubkey_parse(ByteBuffer byteBuff, long context, int inputLen); - - private static native byte[][] secp256k1_ecdh(ByteBuffer byteBuff, long context, int inputLen); - -} diff --git a/src/secp256k1/src/java/org/bitcoin/NativeSecp256k1Test.java b/src/secp256k1/src/java/org/bitcoin/NativeSecp256k1Test.java deleted file mode 100644 index d766a1029c..0000000000 --- a/src/secp256k1/src/java/org/bitcoin/NativeSecp256k1Test.java +++ /dev/null @@ -1,226 +0,0 @@ -package org.bitcoin; - -import com.google.common.io.BaseEncoding; -import java.util.Arrays; -import java.math.BigInteger; -import javax.xml.bind.DatatypeConverter; -import static org.bitcoin.NativeSecp256k1Util.*; - -/** - * This class holds test cases defined for testing this library. - */ -public class NativeSecp256k1Test { - - //TODO improve comments/add more tests - /** - * This tests verify() for a valid signature - */ - public static void testVerifyPos() throws AssertFailException{ - boolean result = false; - byte[] data = BaseEncoding.base16().lowerCase().decode("CF80CD8AED482D5D1527D7DC72FCEFF84E6326592848447D2DC0B0E87DFC9A90".toLowerCase()); //sha256hash of "testing" - byte[] sig = BaseEncoding.base16().lowerCase().decode("3044022079BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F817980220294F14E883B3F525B5367756C2A11EF6CF84B730B36C17CB0C56F0AAB2C98589".toLowerCase()); - byte[] pub = BaseEncoding.base16().lowerCase().decode("040A629506E1B65CD9D2E0BA9C75DF9C4FED0DB16DC9625ED14397F0AFC836FAE595DC53F8B0EFE61E703075BD9B143BAC75EC0E19F82A2208CAEB32BE53414C40".toLowerCase()); - - result = NativeSecp256k1.verify( data, sig, pub); - assertEquals( result, true , "testVerifyPos"); - } - - /** - * This tests verify() for a non-valid signature - */ - public static void testVerifyNeg() throws AssertFailException{ - boolean result = false; - byte[] data = BaseEncoding.base16().lowerCase().decode("CF80CD8AED482D5D1527D7DC72FCEFF84E6326592848447D2DC0B0E87DFC9A91".toLowerCase()); //sha256hash of "testing" - byte[] sig = BaseEncoding.base16().lowerCase().decode("3044022079BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F817980220294F14E883B3F525B5367756C2A11EF6CF84B730B36C17CB0C56F0AAB2C98589".toLowerCase()); - byte[] pub = BaseEncoding.base16().lowerCase().decode("040A629506E1B65CD9D2E0BA9C75DF9C4FED0DB16DC9625ED14397F0AFC836FAE595DC53F8B0EFE61E703075BD9B143BAC75EC0E19F82A2208CAEB32BE53414C40".toLowerCase()); - - result = NativeSecp256k1.verify( data, sig, pub); - //System.out.println(" TEST " + new BigInteger(1, resultbytes).toString(16)); - assertEquals( result, false , "testVerifyNeg"); - } - - /** - * This tests secret key verify() for a valid secretkey - */ - public static void testSecKeyVerifyPos() throws AssertFailException{ - boolean result = false; - byte[] sec = BaseEncoding.base16().lowerCase().decode("67E56582298859DDAE725F972992A07C6C4FB9F62A8FFF58CE3CA926A1063530".toLowerCase()); - - result = NativeSecp256k1.secKeyVerify( sec ); - //System.out.println(" TEST " + new BigInteger(1, resultbytes).toString(16)); - assertEquals( result, true , "testSecKeyVerifyPos"); - } - - /** - * This tests secret key verify() for an invalid secretkey - */ - public static void testSecKeyVerifyNeg() throws AssertFailException{ - boolean result = false; - byte[] sec = BaseEncoding.base16().lowerCase().decode("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".toLowerCase()); - - result = NativeSecp256k1.secKeyVerify( sec ); - //System.out.println(" TEST " + new BigInteger(1, resultbytes).toString(16)); - assertEquals( result, false , "testSecKeyVerifyNeg"); - } - - /** - * This tests public key create() for a valid secretkey - */ - public static void testPubKeyCreatePos() throws AssertFailException{ - byte[] sec = BaseEncoding.base16().lowerCase().decode("67E56582298859DDAE725F972992A07C6C4FB9F62A8FFF58CE3CA926A1063530".toLowerCase()); - - byte[] resultArr = NativeSecp256k1.computePubkey( sec); - String pubkeyString = javax.xml.bind.DatatypeConverter.printHexBinary(resultArr); - assertEquals( pubkeyString , "04C591A8FF19AC9C4E4E5793673B83123437E975285E7B442F4EE2654DFFCA5E2D2103ED494718C697AC9AEBCFD19612E224DB46661011863ED2FC54E71861E2A6" , "testPubKeyCreatePos"); - } - - /** - * This tests public key create() for a invalid secretkey - */ - public static void testPubKeyCreateNeg() throws AssertFailException{ - byte[] sec = BaseEncoding.base16().lowerCase().decode("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".toLowerCase()); - - byte[] resultArr = NativeSecp256k1.computePubkey( sec); - String pubkeyString = javax.xml.bind.DatatypeConverter.printHexBinary(resultArr); - assertEquals( pubkeyString, "" , "testPubKeyCreateNeg"); - } - - /** - * This tests sign() for a valid secretkey - */ - public static void testSignPos() throws AssertFailException{ - - byte[] data = BaseEncoding.base16().lowerCase().decode("CF80CD8AED482D5D1527D7DC72FCEFF84E6326592848447D2DC0B0E87DFC9A90".toLowerCase()); //sha256hash of "testing" - byte[] sec = BaseEncoding.base16().lowerCase().decode("67E56582298859DDAE725F972992A07C6C4FB9F62A8FFF58CE3CA926A1063530".toLowerCase()); - - byte[] resultArr = NativeSecp256k1.sign(data, sec); - String sigString = javax.xml.bind.DatatypeConverter.printHexBinary(resultArr); - assertEquals( sigString, "30440220182A108E1448DC8F1FB467D06A0F3BB8EA0533584CB954EF8DA112F1D60E39A202201C66F36DA211C087F3AF88B50EDF4F9BDAA6CF5FD6817E74DCA34DB12390C6E9" , "testSignPos"); - } - - /** - * This tests sign() for a invalid secretkey - */ - public static void testSignNeg() throws AssertFailException{ - byte[] data = BaseEncoding.base16().lowerCase().decode("CF80CD8AED482D5D1527D7DC72FCEFF84E6326592848447D2DC0B0E87DFC9A90".toLowerCase()); //sha256hash of "testing" - byte[] sec = BaseEncoding.base16().lowerCase().decode("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".toLowerCase()); - - byte[] resultArr = NativeSecp256k1.sign(data, sec); - String sigString = javax.xml.bind.DatatypeConverter.printHexBinary(resultArr); - assertEquals( sigString, "" , "testSignNeg"); - } - - /** - * This tests private key tweak-add - */ - public static void testPrivKeyTweakAdd_1() throws AssertFailException { - byte[] sec = BaseEncoding.base16().lowerCase().decode("67E56582298859DDAE725F972992A07C6C4FB9F62A8FFF58CE3CA926A1063530".toLowerCase()); - byte[] data = BaseEncoding.base16().lowerCase().decode("3982F19BEF1615BCCFBB05E321C10E1D4CBA3DF0E841C2E41EEB6016347653C3".toLowerCase()); //sha256hash of "tweak" - - byte[] resultArr = NativeSecp256k1.privKeyTweakAdd( sec , data ); - String sigString = javax.xml.bind.DatatypeConverter.printHexBinary(resultArr); - assertEquals( sigString , "A168571E189E6F9A7E2D657A4B53AE99B909F7E712D1C23CED28093CD57C88F3" , "testPrivKeyAdd_1"); - } - - /** - * This tests private key tweak-mul - */ - public static void testPrivKeyTweakMul_1() throws AssertFailException { - byte[] sec = BaseEncoding.base16().lowerCase().decode("67E56582298859DDAE725F972992A07C6C4FB9F62A8FFF58CE3CA926A1063530".toLowerCase()); - byte[] data = BaseEncoding.base16().lowerCase().decode("3982F19BEF1615BCCFBB05E321C10E1D4CBA3DF0E841C2E41EEB6016347653C3".toLowerCase()); //sha256hash of "tweak" - - byte[] resultArr = NativeSecp256k1.privKeyTweakMul( sec , data ); - String sigString = javax.xml.bind.DatatypeConverter.printHexBinary(resultArr); - assertEquals( sigString , "97F8184235F101550F3C71C927507651BD3F1CDB4A5A33B8986ACF0DEE20FFFC" , "testPrivKeyMul_1"); - } - - /** - * This tests private key tweak-add uncompressed - */ - public static void testPrivKeyTweakAdd_2() throws AssertFailException { - byte[] pub = BaseEncoding.base16().lowerCase().decode("040A629506E1B65CD9D2E0BA9C75DF9C4FED0DB16DC9625ED14397F0AFC836FAE595DC53F8B0EFE61E703075BD9B143BAC75EC0E19F82A2208CAEB32BE53414C40".toLowerCase()); - byte[] data = BaseEncoding.base16().lowerCase().decode("3982F19BEF1615BCCFBB05E321C10E1D4CBA3DF0E841C2E41EEB6016347653C3".toLowerCase()); //sha256hash of "tweak" - - byte[] resultArr = NativeSecp256k1.pubKeyTweakAdd( pub , data ); - String sigString = javax.xml.bind.DatatypeConverter.printHexBinary(resultArr); - assertEquals( sigString , "0411C6790F4B663CCE607BAAE08C43557EDC1A4D11D88DFCB3D841D0C6A941AF525A268E2A863C148555C48FB5FBA368E88718A46E205FABC3DBA2CCFFAB0796EF" , "testPrivKeyAdd_2"); - } - - /** - * This tests private key tweak-mul uncompressed - */ - public static void testPrivKeyTweakMul_2() throws AssertFailException { - byte[] pub = BaseEncoding.base16().lowerCase().decode("040A629506E1B65CD9D2E0BA9C75DF9C4FED0DB16DC9625ED14397F0AFC836FAE595DC53F8B0EFE61E703075BD9B143BAC75EC0E19F82A2208CAEB32BE53414C40".toLowerCase()); - byte[] data = BaseEncoding.base16().lowerCase().decode("3982F19BEF1615BCCFBB05E321C10E1D4CBA3DF0E841C2E41EEB6016347653C3".toLowerCase()); //sha256hash of "tweak" - - byte[] resultArr = NativeSecp256k1.pubKeyTweakMul( pub , data ); - String sigString = javax.xml.bind.DatatypeConverter.printHexBinary(resultArr); - assertEquals( sigString , "04E0FE6FE55EBCA626B98A807F6CAF654139E14E5E3698F01A9A658E21DC1D2791EC060D4F412A794D5370F672BC94B722640B5F76914151CFCA6E712CA48CC589" , "testPrivKeyMul_2"); - } - - /** - * This tests seed randomization - */ - public static void testRandomize() throws AssertFailException { - byte[] seed = BaseEncoding.base16().lowerCase().decode("A441B15FE9A3CF56661190A0B93B9DEC7D04127288CC87250967CF3B52894D11".toLowerCase()); //sha256hash of "random" - boolean result = NativeSecp256k1.randomize(seed); - assertEquals( result, true, "testRandomize"); - } - - public static void testCreateECDHSecret() throws AssertFailException{ - - byte[] sec = BaseEncoding.base16().lowerCase().decode("67E56582298859DDAE725F972992A07C6C4FB9F62A8FFF58CE3CA926A1063530".toLowerCase()); - byte[] pub = BaseEncoding.base16().lowerCase().decode("040A629506E1B65CD9D2E0BA9C75DF9C4FED0DB16DC9625ED14397F0AFC836FAE595DC53F8B0EFE61E703075BD9B143BAC75EC0E19F82A2208CAEB32BE53414C40".toLowerCase()); - - byte[] resultArr = NativeSecp256k1.createECDHSecret(sec, pub); - String ecdhString = javax.xml.bind.DatatypeConverter.printHexBinary(resultArr); - assertEquals( ecdhString, "2A2A67007A926E6594AF3EB564FC74005B37A9C8AEF2033C4552051B5C87F043" , "testCreateECDHSecret"); - } - - public static void main(String[] args) throws AssertFailException{ - - - System.out.println("\n libsecp256k1 enabled: " + Secp256k1Context.isEnabled() + "\n"); - - assertEquals( Secp256k1Context.isEnabled(), true, "isEnabled" ); - - //Test verify() success/fail - testVerifyPos(); - testVerifyNeg(); - - //Test secKeyVerify() success/fail - testSecKeyVerifyPos(); - testSecKeyVerifyNeg(); - - //Test computePubkey() success/fail - testPubKeyCreatePos(); - testPubKeyCreateNeg(); - - //Test sign() success/fail - testSignPos(); - testSignNeg(); - - //Test privKeyTweakAdd() 1 - testPrivKeyTweakAdd_1(); - - //Test privKeyTweakMul() 2 - testPrivKeyTweakMul_1(); - - //Test privKeyTweakAdd() 3 - testPrivKeyTweakAdd_2(); - - //Test privKeyTweakMul() 4 - testPrivKeyTweakMul_2(); - - //Test randomize() - testRandomize(); - - //Test ECDH - testCreateECDHSecret(); - - NativeSecp256k1.cleanup(); - - System.out.println(" All tests passed." ); - - } -} diff --git a/src/secp256k1/src/java/org/bitcoin/NativeSecp256k1Util.java b/src/secp256k1/src/java/org/bitcoin/NativeSecp256k1Util.java deleted file mode 100644 index 04732ba044..0000000000 --- a/src/secp256k1/src/java/org/bitcoin/NativeSecp256k1Util.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright 2014-2016 the libsecp256k1 contributors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.bitcoin; - -public class NativeSecp256k1Util{ - - public static void assertEquals( int val, int val2, String message ) throws AssertFailException{ - if( val != val2 ) - throw new AssertFailException("FAIL: " + message); - } - - public static void assertEquals( boolean val, boolean val2, String message ) throws AssertFailException{ - if( val != val2 ) - throw new AssertFailException("FAIL: " + message); - else - System.out.println("PASS: " + message); - } - - public static void assertEquals( String val, String val2, String message ) throws AssertFailException{ - if( !val.equals(val2) ) - throw new AssertFailException("FAIL: " + message); - else - System.out.println("PASS: " + message); - } - - public static class AssertFailException extends Exception { - public AssertFailException(String message) { - super( message ); - } - } -} diff --git a/src/secp256k1/src/java/org/bitcoin/Secp256k1Context.java b/src/secp256k1/src/java/org/bitcoin/Secp256k1Context.java deleted file mode 100644 index 216c986a8b..0000000000 --- a/src/secp256k1/src/java/org/bitcoin/Secp256k1Context.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright 2014-2016 the libsecp256k1 contributors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.bitcoin; - -/** - * This class holds the context reference used in native methods - * to handle ECDSA operations. - */ -public class Secp256k1Context { - private static final boolean enabled; //true if the library is loaded - private static final long context; //ref to pointer to context obj - - static { //static initializer - boolean isEnabled = true; - long contextRef = -1; - try { - System.loadLibrary("secp256k1"); - contextRef = secp256k1_init_context(); - } catch (UnsatisfiedLinkError e) { - System.out.println("UnsatisfiedLinkError: " + e.toString()); - isEnabled = false; - } - enabled = isEnabled; - context = contextRef; - } - - public static boolean isEnabled() { - return enabled; - } - - public static long getContext() { - if(!enabled) return -1; //sanity check - return context; - } - - private static native long secp256k1_init_context(); -} diff --git a/src/secp256k1/src/java/org_bitcoin_NativeSecp256k1.c b/src/secp256k1/src/java/org_bitcoin_NativeSecp256k1.c deleted file mode 100644 index b50970b4f2..0000000000 --- a/src/secp256k1/src/java/org_bitcoin_NativeSecp256k1.c +++ /dev/null @@ -1,379 +0,0 @@ -#include <stdlib.h> -#include <stdint.h> -#include <string.h> -#include "org_bitcoin_NativeSecp256k1.h" -#include "include/secp256k1.h" -#include "include/secp256k1_ecdh.h" -#include "include/secp256k1_recovery.h" - - -SECP256K1_API jlong JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1ctx_1clone - (JNIEnv* env, jclass classObject, jlong ctx_l) -{ - const secp256k1_context *ctx = (secp256k1_context*)(uintptr_t)ctx_l; - - jlong ctx_clone_l = (uintptr_t) secp256k1_context_clone(ctx); - - (void)classObject;(void)env; - - return ctx_clone_l; - -} - -SECP256K1_API jint JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1context_1randomize - (JNIEnv* env, jclass classObject, jobject byteBufferObject, jlong ctx_l) -{ - secp256k1_context *ctx = (secp256k1_context*)(uintptr_t)ctx_l; - - const unsigned char* seed = (unsigned char*) (*env)->GetDirectBufferAddress(env, byteBufferObject); - - (void)classObject; - - return secp256k1_context_randomize(ctx, seed); - -} - -SECP256K1_API void JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1destroy_1context - (JNIEnv* env, jclass classObject, jlong ctx_l) -{ - secp256k1_context *ctx = (secp256k1_context*)(uintptr_t)ctx_l; - - secp256k1_context_destroy(ctx); - - (void)classObject;(void)env; -} - -SECP256K1_API jint JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1ecdsa_1verify - (JNIEnv* env, jclass classObject, jobject byteBufferObject, jlong ctx_l, jint siglen, jint publen) -{ - secp256k1_context *ctx = (secp256k1_context*)(uintptr_t)ctx_l; - - unsigned char* data = (unsigned char*) (*env)->GetDirectBufferAddress(env, byteBufferObject); - const unsigned char* sigdata = { (unsigned char*) (data + 32) }; - const unsigned char* pubdata = { (unsigned char*) (data + siglen + 32) }; - - secp256k1_ecdsa_signature sig; - secp256k1_pubkey pubkey; - - int ret = secp256k1_ecdsa_signature_parse_der(ctx, &sig, sigdata, siglen); - - if( ret ) { - ret = secp256k1_ec_pubkey_parse(ctx, &pubkey, pubdata, publen); - - if( ret ) { - ret = secp256k1_ecdsa_verify(ctx, &sig, data, &pubkey); - } - } - - (void)classObject; - - return ret; -} - -SECP256K1_API jobjectArray JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1ecdsa_1sign - (JNIEnv* env, jclass classObject, jobject byteBufferObject, jlong ctx_l) -{ - secp256k1_context *ctx = (secp256k1_context*)(uintptr_t)ctx_l; - unsigned char* data = (unsigned char*) (*env)->GetDirectBufferAddress(env, byteBufferObject); - unsigned char* secKey = (unsigned char*) (data + 32); - - jobjectArray retArray; - jbyteArray sigArray, intsByteArray; - unsigned char intsarray[2]; - - secp256k1_ecdsa_signature sig[72]; - - int ret = secp256k1_ecdsa_sign(ctx, sig, data, secKey, NULL, NULL); - - unsigned char outputSer[72]; - size_t outputLen = 72; - - if( ret ) { - int ret2 = secp256k1_ecdsa_signature_serialize_der(ctx,outputSer, &outputLen, sig ); (void)ret2; - } - - intsarray[0] = outputLen; - intsarray[1] = ret; - - retArray = (*env)->NewObjectArray(env, 2, - (*env)->FindClass(env, "[B"), - (*env)->NewByteArray(env, 1)); - - sigArray = (*env)->NewByteArray(env, outputLen); - (*env)->SetByteArrayRegion(env, sigArray, 0, outputLen, (jbyte*)outputSer); - (*env)->SetObjectArrayElement(env, retArray, 0, sigArray); - - intsByteArray = (*env)->NewByteArray(env, 2); - (*env)->SetByteArrayRegion(env, intsByteArray, 0, 2, (jbyte*)intsarray); - (*env)->SetObjectArrayElement(env, retArray, 1, intsByteArray); - - (void)classObject; - - return retArray; -} - -SECP256K1_API jint JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1ec_1seckey_1verify - (JNIEnv* env, jclass classObject, jobject byteBufferObject, jlong ctx_l) -{ - secp256k1_context *ctx = (secp256k1_context*)(uintptr_t)ctx_l; - unsigned char* secKey = (unsigned char*) (*env)->GetDirectBufferAddress(env, byteBufferObject); - - (void)classObject; - - return secp256k1_ec_seckey_verify(ctx, secKey); -} - -SECP256K1_API jobjectArray JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1ec_1pubkey_1create - (JNIEnv* env, jclass classObject, jobject byteBufferObject, jlong ctx_l) -{ - secp256k1_context *ctx = (secp256k1_context*)(uintptr_t)ctx_l; - const unsigned char* secKey = (unsigned char*) (*env)->GetDirectBufferAddress(env, byteBufferObject); - - secp256k1_pubkey pubkey; - - jobjectArray retArray; - jbyteArray pubkeyArray, intsByteArray; - unsigned char intsarray[2]; - - int ret = secp256k1_ec_pubkey_create(ctx, &pubkey, secKey); - - unsigned char outputSer[65]; - size_t outputLen = 65; - - if( ret ) { - int ret2 = secp256k1_ec_pubkey_serialize(ctx,outputSer, &outputLen, &pubkey,SECP256K1_EC_UNCOMPRESSED );(void)ret2; - } - - intsarray[0] = outputLen; - intsarray[1] = ret; - - retArray = (*env)->NewObjectArray(env, 2, - (*env)->FindClass(env, "[B"), - (*env)->NewByteArray(env, 1)); - - pubkeyArray = (*env)->NewByteArray(env, outputLen); - (*env)->SetByteArrayRegion(env, pubkeyArray, 0, outputLen, (jbyte*)outputSer); - (*env)->SetObjectArrayElement(env, retArray, 0, pubkeyArray); - - intsByteArray = (*env)->NewByteArray(env, 2); - (*env)->SetByteArrayRegion(env, intsByteArray, 0, 2, (jbyte*)intsarray); - (*env)->SetObjectArrayElement(env, retArray, 1, intsByteArray); - - (void)classObject; - - return retArray; - -} - -SECP256K1_API jobjectArray JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1privkey_1tweak_1add - (JNIEnv* env, jclass classObject, jobject byteBufferObject, jlong ctx_l) -{ - secp256k1_context *ctx = (secp256k1_context*)(uintptr_t)ctx_l; - unsigned char* privkey = (unsigned char*) (*env)->GetDirectBufferAddress(env, byteBufferObject); - const unsigned char* tweak = (unsigned char*) (privkey + 32); - - jobjectArray retArray; - jbyteArray privArray, intsByteArray; - unsigned char intsarray[2]; - - int privkeylen = 32; - - int ret = secp256k1_ec_privkey_tweak_add(ctx, privkey, tweak); - - intsarray[0] = privkeylen; - intsarray[1] = ret; - - retArray = (*env)->NewObjectArray(env, 2, - (*env)->FindClass(env, "[B"), - (*env)->NewByteArray(env, 1)); - - privArray = (*env)->NewByteArray(env, privkeylen); - (*env)->SetByteArrayRegion(env, privArray, 0, privkeylen, (jbyte*)privkey); - (*env)->SetObjectArrayElement(env, retArray, 0, privArray); - - intsByteArray = (*env)->NewByteArray(env, 2); - (*env)->SetByteArrayRegion(env, intsByteArray, 0, 2, (jbyte*)intsarray); - (*env)->SetObjectArrayElement(env, retArray, 1, intsByteArray); - - (void)classObject; - - return retArray; -} - -SECP256K1_API jobjectArray JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1privkey_1tweak_1mul - (JNIEnv* env, jclass classObject, jobject byteBufferObject, jlong ctx_l) -{ - secp256k1_context *ctx = (secp256k1_context*)(uintptr_t)ctx_l; - unsigned char* privkey = (unsigned char*) (*env)->GetDirectBufferAddress(env, byteBufferObject); - const unsigned char* tweak = (unsigned char*) (privkey + 32); - - jobjectArray retArray; - jbyteArray privArray, intsByteArray; - unsigned char intsarray[2]; - - int privkeylen = 32; - - int ret = secp256k1_ec_privkey_tweak_mul(ctx, privkey, tweak); - - intsarray[0] = privkeylen; - intsarray[1] = ret; - - retArray = (*env)->NewObjectArray(env, 2, - (*env)->FindClass(env, "[B"), - (*env)->NewByteArray(env, 1)); - - privArray = (*env)->NewByteArray(env, privkeylen); - (*env)->SetByteArrayRegion(env, privArray, 0, privkeylen, (jbyte*)privkey); - (*env)->SetObjectArrayElement(env, retArray, 0, privArray); - - intsByteArray = (*env)->NewByteArray(env, 2); - (*env)->SetByteArrayRegion(env, intsByteArray, 0, 2, (jbyte*)intsarray); - (*env)->SetObjectArrayElement(env, retArray, 1, intsByteArray); - - (void)classObject; - - return retArray; -} - -SECP256K1_API jobjectArray JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1pubkey_1tweak_1add - (JNIEnv* env, jclass classObject, jobject byteBufferObject, jlong ctx_l, jint publen) -{ - secp256k1_context *ctx = (secp256k1_context*)(uintptr_t)ctx_l; -/* secp256k1_pubkey* pubkey = (secp256k1_pubkey*) (*env)->GetDirectBufferAddress(env, byteBufferObject);*/ - unsigned char* pkey = (*env)->GetDirectBufferAddress(env, byteBufferObject); - const unsigned char* tweak = (unsigned char*) (pkey + publen); - - jobjectArray retArray; - jbyteArray pubArray, intsByteArray; - unsigned char intsarray[2]; - unsigned char outputSer[65]; - size_t outputLen = 65; - - secp256k1_pubkey pubkey; - int ret = secp256k1_ec_pubkey_parse(ctx, &pubkey, pkey, publen); - - if( ret ) { - ret = secp256k1_ec_pubkey_tweak_add(ctx, &pubkey, tweak); - } - - if( ret ) { - int ret2 = secp256k1_ec_pubkey_serialize(ctx,outputSer, &outputLen, &pubkey,SECP256K1_EC_UNCOMPRESSED );(void)ret2; - } - - intsarray[0] = outputLen; - intsarray[1] = ret; - - retArray = (*env)->NewObjectArray(env, 2, - (*env)->FindClass(env, "[B"), - (*env)->NewByteArray(env, 1)); - - pubArray = (*env)->NewByteArray(env, outputLen); - (*env)->SetByteArrayRegion(env, pubArray, 0, outputLen, (jbyte*)outputSer); - (*env)->SetObjectArrayElement(env, retArray, 0, pubArray); - - intsByteArray = (*env)->NewByteArray(env, 2); - (*env)->SetByteArrayRegion(env, intsByteArray, 0, 2, (jbyte*)intsarray); - (*env)->SetObjectArrayElement(env, retArray, 1, intsByteArray); - - (void)classObject; - - return retArray; -} - -SECP256K1_API jobjectArray JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1pubkey_1tweak_1mul - (JNIEnv* env, jclass classObject, jobject byteBufferObject, jlong ctx_l, jint publen) -{ - secp256k1_context *ctx = (secp256k1_context*)(uintptr_t)ctx_l; - unsigned char* pkey = (*env)->GetDirectBufferAddress(env, byteBufferObject); - const unsigned char* tweak = (unsigned char*) (pkey + publen); - - jobjectArray retArray; - jbyteArray pubArray, intsByteArray; - unsigned char intsarray[2]; - unsigned char outputSer[65]; - size_t outputLen = 65; - - secp256k1_pubkey pubkey; - int ret = secp256k1_ec_pubkey_parse(ctx, &pubkey, pkey, publen); - - if ( ret ) { - ret = secp256k1_ec_pubkey_tweak_mul(ctx, &pubkey, tweak); - } - - if( ret ) { - int ret2 = secp256k1_ec_pubkey_serialize(ctx,outputSer, &outputLen, &pubkey,SECP256K1_EC_UNCOMPRESSED );(void)ret2; - } - - intsarray[0] = outputLen; - intsarray[1] = ret; - - retArray = (*env)->NewObjectArray(env, 2, - (*env)->FindClass(env, "[B"), - (*env)->NewByteArray(env, 1)); - - pubArray = (*env)->NewByteArray(env, outputLen); - (*env)->SetByteArrayRegion(env, pubArray, 0, outputLen, (jbyte*)outputSer); - (*env)->SetObjectArrayElement(env, retArray, 0, pubArray); - - intsByteArray = (*env)->NewByteArray(env, 2); - (*env)->SetByteArrayRegion(env, intsByteArray, 0, 2, (jbyte*)intsarray); - (*env)->SetObjectArrayElement(env, retArray, 1, intsByteArray); - - (void)classObject; - - return retArray; -} - -SECP256K1_API jlong JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1ecdsa_1pubkey_1combine - (JNIEnv * env, jclass classObject, jobject byteBufferObject, jlong ctx_l, jint numkeys) -{ - (void)classObject;(void)env;(void)byteBufferObject;(void)ctx_l;(void)numkeys; - - return 0; -} - -SECP256K1_API jobjectArray JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1ecdh - (JNIEnv* env, jclass classObject, jobject byteBufferObject, jlong ctx_l, jint publen) -{ - secp256k1_context *ctx = (secp256k1_context*)(uintptr_t)ctx_l; - const unsigned char* secdata = (*env)->GetDirectBufferAddress(env, byteBufferObject); - const unsigned char* pubdata = (const unsigned char*) (secdata + 32); - - jobjectArray retArray; - jbyteArray outArray, intsByteArray; - unsigned char intsarray[1]; - secp256k1_pubkey pubkey; - unsigned char nonce_res[32]; - size_t outputLen = 32; - - int ret = secp256k1_ec_pubkey_parse(ctx, &pubkey, pubdata, publen); - - if (ret) { - ret = secp256k1_ecdh( - ctx, - nonce_res, - &pubkey, - secdata, - NULL, - NULL - ); - } - - intsarray[0] = ret; - - retArray = (*env)->NewObjectArray(env, 2, - (*env)->FindClass(env, "[B"), - (*env)->NewByteArray(env, 1)); - - outArray = (*env)->NewByteArray(env, outputLen); - (*env)->SetByteArrayRegion(env, outArray, 0, 32, (jbyte*)nonce_res); - (*env)->SetObjectArrayElement(env, retArray, 0, outArray); - - intsByteArray = (*env)->NewByteArray(env, 1); - (*env)->SetByteArrayRegion(env, intsByteArray, 0, 1, (jbyte*)intsarray); - (*env)->SetObjectArrayElement(env, retArray, 1, intsByteArray); - - (void)classObject; - - return retArray; -} diff --git a/src/secp256k1/src/java/org_bitcoin_NativeSecp256k1.h b/src/secp256k1/src/java/org_bitcoin_NativeSecp256k1.h deleted file mode 100644 index fe613c9e9e..0000000000 --- a/src/secp256k1/src/java/org_bitcoin_NativeSecp256k1.h +++ /dev/null @@ -1,119 +0,0 @@ -/* DO NOT EDIT THIS FILE - it is machine generated */ -#include <jni.h> -#include "include/secp256k1.h" -/* Header for class org_bitcoin_NativeSecp256k1 */ - -#ifndef _Included_org_bitcoin_NativeSecp256k1 -#define _Included_org_bitcoin_NativeSecp256k1 -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: org_bitcoin_NativeSecp256k1 - * Method: secp256k1_ctx_clone - * Signature: (J)J - */ -SECP256K1_API jlong JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1ctx_1clone - (JNIEnv *, jclass, jlong); - -/* - * Class: org_bitcoin_NativeSecp256k1 - * Method: secp256k1_context_randomize - * Signature: (Ljava/nio/ByteBuffer;J)I - */ -SECP256K1_API jint JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1context_1randomize - (JNIEnv *, jclass, jobject, jlong); - -/* - * Class: org_bitcoin_NativeSecp256k1 - * Method: secp256k1_privkey_tweak_add - * Signature: (Ljava/nio/ByteBuffer;J)[[B - */ -SECP256K1_API jobjectArray JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1privkey_1tweak_1add - (JNIEnv *, jclass, jobject, jlong); - -/* - * Class: org_bitcoin_NativeSecp256k1 - * Method: secp256k1_privkey_tweak_mul - * Signature: (Ljava/nio/ByteBuffer;J)[[B - */ -SECP256K1_API jobjectArray JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1privkey_1tweak_1mul - (JNIEnv *, jclass, jobject, jlong); - -/* - * Class: org_bitcoin_NativeSecp256k1 - * Method: secp256k1_pubkey_tweak_add - * Signature: (Ljava/nio/ByteBuffer;JI)[[B - */ -SECP256K1_API jobjectArray JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1pubkey_1tweak_1add - (JNIEnv *, jclass, jobject, jlong, jint); - -/* - * Class: org_bitcoin_NativeSecp256k1 - * Method: secp256k1_pubkey_tweak_mul - * Signature: (Ljava/nio/ByteBuffer;JI)[[B - */ -SECP256K1_API jobjectArray JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1pubkey_1tweak_1mul - (JNIEnv *, jclass, jobject, jlong, jint); - -/* - * Class: org_bitcoin_NativeSecp256k1 - * Method: secp256k1_destroy_context - * Signature: (J)V - */ -SECP256K1_API void JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1destroy_1context - (JNIEnv *, jclass, jlong); - -/* - * Class: org_bitcoin_NativeSecp256k1 - * Method: secp256k1_ecdsa_verify - * Signature: (Ljava/nio/ByteBuffer;JII)I - */ -SECP256K1_API jint JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1ecdsa_1verify - (JNIEnv *, jclass, jobject, jlong, jint, jint); - -/* - * Class: org_bitcoin_NativeSecp256k1 - * Method: secp256k1_ecdsa_sign - * Signature: (Ljava/nio/ByteBuffer;J)[[B - */ -SECP256K1_API jobjectArray JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1ecdsa_1sign - (JNIEnv *, jclass, jobject, jlong); - -/* - * Class: org_bitcoin_NativeSecp256k1 - * Method: secp256k1_ec_seckey_verify - * Signature: (Ljava/nio/ByteBuffer;J)I - */ -SECP256K1_API jint JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1ec_1seckey_1verify - (JNIEnv *, jclass, jobject, jlong); - -/* - * Class: org_bitcoin_NativeSecp256k1 - * Method: secp256k1_ec_pubkey_create - * Signature: (Ljava/nio/ByteBuffer;J)[[B - */ -SECP256K1_API jobjectArray JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1ec_1pubkey_1create - (JNIEnv *, jclass, jobject, jlong); - -/* - * Class: org_bitcoin_NativeSecp256k1 - * Method: secp256k1_ec_pubkey_parse - * Signature: (Ljava/nio/ByteBuffer;JI)[[B - */ -SECP256K1_API jobjectArray JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1ec_1pubkey_1parse - (JNIEnv *, jclass, jobject, jlong, jint); - -/* - * Class: org_bitcoin_NativeSecp256k1 - * Method: secp256k1_ecdh - * Signature: (Ljava/nio/ByteBuffer;JI)[[B - */ -SECP256K1_API jobjectArray JNICALL Java_org_bitcoin_NativeSecp256k1_secp256k1_1ecdh - (JNIEnv* env, jclass classObject, jobject byteBufferObject, jlong ctx_l, jint publen); - - -#ifdef __cplusplus -} -#endif -#endif diff --git a/src/secp256k1/src/java/org_bitcoin_Secp256k1Context.c b/src/secp256k1/src/java/org_bitcoin_Secp256k1Context.c deleted file mode 100644 index a52939e7e7..0000000000 --- a/src/secp256k1/src/java/org_bitcoin_Secp256k1Context.c +++ /dev/null @@ -1,15 +0,0 @@ -#include <stdlib.h> -#include <stdint.h> -#include "org_bitcoin_Secp256k1Context.h" -#include "include/secp256k1.h" - -SECP256K1_API jlong JNICALL Java_org_bitcoin_Secp256k1Context_secp256k1_1init_1context - (JNIEnv* env, jclass classObject) -{ - secp256k1_context *ctx = secp256k1_context_create(SECP256K1_CONTEXT_SIGN | SECP256K1_CONTEXT_VERIFY); - - (void)classObject;(void)env; - - return (uintptr_t)ctx; -} - diff --git a/src/secp256k1/src/java/org_bitcoin_Secp256k1Context.h b/src/secp256k1/src/java/org_bitcoin_Secp256k1Context.h deleted file mode 100644 index 0d2bc84b7f..0000000000 --- a/src/secp256k1/src/java/org_bitcoin_Secp256k1Context.h +++ /dev/null @@ -1,22 +0,0 @@ -/* DO NOT EDIT THIS FILE - it is machine generated */ -#include <jni.h> -#include "include/secp256k1.h" -/* Header for class org_bitcoin_Secp256k1Context */ - -#ifndef _Included_org_bitcoin_Secp256k1Context -#define _Included_org_bitcoin_Secp256k1Context -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: org_bitcoin_Secp256k1Context - * Method: secp256k1_init_context - * Signature: ()J - */ -SECP256K1_API jlong JNICALL Java_org_bitcoin_Secp256k1Context_secp256k1_1init_1context - (JNIEnv *, jclass); - -#ifdef __cplusplus -} -#endif -#endif diff --git a/src/secp256k1/src/modules/ecdh/main_impl.h b/src/secp256k1/src/modules/ecdh/main_impl.h index 44cb68e750..07a25b80d4 100644 --- a/src/secp256k1/src/modules/ecdh/main_impl.h +++ b/src/secp256k1/src/modules/ecdh/main_impl.h @@ -10,14 +10,14 @@ #include "include/secp256k1_ecdh.h" #include "ecmult_const_impl.h" -static int ecdh_hash_function_sha256(unsigned char *output, const unsigned char *x, const unsigned char *y, void *data) { - unsigned char version = (y[31] & 0x01) | 0x02; +static int ecdh_hash_function_sha256(unsigned char *output, const unsigned char *x32, const unsigned char *y32, void *data) { + unsigned char version = (y32[31] & 0x01) | 0x02; secp256k1_sha256 sha; (void)data; secp256k1_sha256_initialize(&sha); secp256k1_sha256_write(&sha, &version, 1); - secp256k1_sha256_write(&sha, x, 32); + secp256k1_sha256_write(&sha, x32, 32); secp256k1_sha256_finalize(&sha, output); return 1; @@ -32,36 +32,40 @@ int secp256k1_ecdh(const secp256k1_context* ctx, unsigned char *output, const se secp256k1_gej res; secp256k1_ge pt; secp256k1_scalar s; + unsigned char x[32]; + unsigned char y[32]; + VERIFY_CHECK(ctx != NULL); ARG_CHECK(output != NULL); ARG_CHECK(point != NULL); ARG_CHECK(scalar != NULL); + if (hashfp == NULL) { hashfp = secp256k1_ecdh_hash_function_default; } secp256k1_pubkey_load(ctx, &pt, point); secp256k1_scalar_set_b32(&s, scalar, &overflow); - if (overflow || secp256k1_scalar_is_zero(&s)) { - ret = 0; - } else { - unsigned char x[32]; - unsigned char y[32]; - - secp256k1_ecmult_const(&res, &pt, &s, 256); - secp256k1_ge_set_gej(&pt, &res); - - /* Compute a hash of the point */ - secp256k1_fe_normalize(&pt.x); - secp256k1_fe_normalize(&pt.y); - secp256k1_fe_get_b32(x, &pt.x); - secp256k1_fe_get_b32(y, &pt.y); - - ret = hashfp(output, x, y, data); - } + overflow |= secp256k1_scalar_is_zero(&s); + secp256k1_scalar_cmov(&s, &secp256k1_scalar_one, overflow); + + secp256k1_ecmult_const(&res, &pt, &s, 256); + secp256k1_ge_set_gej(&pt, &res); + + /* Compute a hash of the point */ + secp256k1_fe_normalize(&pt.x); + secp256k1_fe_normalize(&pt.y); + secp256k1_fe_get_b32(x, &pt.x); + secp256k1_fe_get_b32(y, &pt.y); + + ret = hashfp(output, x, y, data); + + memset(x, 0, 32); + memset(y, 0, 32); secp256k1_scalar_clear(&s); - return ret; + + return !!ret & !overflow; } #endif /* SECP256K1_MODULE_ECDH_MAIN_H */ diff --git a/src/secp256k1/src/modules/recovery/main_impl.h b/src/secp256k1/src/modules/recovery/main_impl.h index 2f6691c5a1..e2576aa953 100755..100644 --- a/src/secp256k1/src/modules/recovery/main_impl.h +++ b/src/secp256k1/src/modules/recovery/main_impl.h @@ -122,48 +122,15 @@ static int secp256k1_ecdsa_sig_recover(const secp256k1_ecmult_context *ctx, cons int secp256k1_ecdsa_sign_recoverable(const secp256k1_context* ctx, secp256k1_ecdsa_recoverable_signature *signature, const unsigned char *msg32, const unsigned char *seckey, secp256k1_nonce_function noncefp, const void* noncedata) { secp256k1_scalar r, s; - secp256k1_scalar sec, non, msg; - int recid; - int ret = 0; - int overflow = 0; + int ret, recid; VERIFY_CHECK(ctx != NULL); ARG_CHECK(secp256k1_ecmult_gen_context_is_built(&ctx->ecmult_gen_ctx)); ARG_CHECK(msg32 != NULL); ARG_CHECK(signature != NULL); ARG_CHECK(seckey != NULL); - if (noncefp == NULL) { - noncefp = secp256k1_nonce_function_default; - } - secp256k1_scalar_set_b32(&sec, seckey, &overflow); - /* Fail if the secret key is invalid. */ - if (!overflow && !secp256k1_scalar_is_zero(&sec)) { - unsigned char nonce32[32]; - unsigned int count = 0; - secp256k1_scalar_set_b32(&msg, msg32, NULL); - while (1) { - ret = noncefp(nonce32, msg32, seckey, NULL, (void*)noncedata, count); - if (!ret) { - break; - } - secp256k1_scalar_set_b32(&non, nonce32, &overflow); - if (!secp256k1_scalar_is_zero(&non) && !overflow) { - if (secp256k1_ecdsa_sig_sign(&ctx->ecmult_gen_ctx, &r, &s, &sec, &msg, &non, &recid)) { - break; - } - } - count++; - } - memset(nonce32, 0, 32); - secp256k1_scalar_clear(&msg); - secp256k1_scalar_clear(&non); - secp256k1_scalar_clear(&sec); - } - if (ret) { - secp256k1_ecdsa_recoverable_signature_save(signature, &r, &s, recid); - } else { - memset(signature, 0, sizeof(*signature)); - } + ret = secp256k1_ecdsa_sign_inner(ctx, &r, &s, &recid, msg32, seckey, noncefp, noncedata); + secp256k1_ecdsa_recoverable_signature_save(signature, &r, &s, recid); return ret; } diff --git a/src/secp256k1/src/modules/recovery/tests_impl.h b/src/secp256k1/src/modules/recovery/tests_impl.h index 5c9bbe8610..38a533a755 100644 --- a/src/secp256k1/src/modules/recovery/tests_impl.h +++ b/src/secp256k1/src/modules/recovery/tests_impl.h @@ -215,7 +215,7 @@ void test_ecdsa_recovery_edge_cases(void) { }; const unsigned char sig64[64] = { /* Generated by signing the above message with nonce 'This is the nonce we will use...' - * and secret key 0 (which is not valid), resulting in recid 0. */ + * and secret key 0 (which is not valid), resulting in recid 1. */ 0x67, 0xCB, 0x28, 0x5F, 0x9C, 0xD1, 0x94, 0xE8, 0x40, 0xD6, 0x29, 0x39, 0x7A, 0xF5, 0x56, 0x96, 0x62, 0xFD, 0xE4, 0x46, 0x49, 0x99, 0x59, 0x63, diff --git a/src/secp256k1/src/scalar.h b/src/secp256k1/src/scalar.h index 59304cb66e..2a74703523 100644 --- a/src/secp256k1/src/scalar.h +++ b/src/secp256k1/src/scalar.h @@ -32,9 +32,17 @@ static unsigned int secp256k1_scalar_get_bits(const secp256k1_scalar *a, unsigne /** Access bits from a scalar. Not constant time. */ static unsigned int secp256k1_scalar_get_bits_var(const secp256k1_scalar *a, unsigned int offset, unsigned int count); -/** Set a scalar from a big endian byte array. */ +/** Set a scalar from a big endian byte array. The scalar will be reduced modulo group order `n`. + * In: bin: pointer to a 32-byte array. + * Out: r: scalar to be set. + * overflow: non-zero if the scalar was bigger or equal to `n` before reduction, zero otherwise (can be NULL). + */ static void secp256k1_scalar_set_b32(secp256k1_scalar *r, const unsigned char *bin, int *overflow); +/** Set a scalar from a big endian byte array and returns 1 if it is a valid + * seckey and 0 otherwise. */ +static int secp256k1_scalar_set_b32_seckey(secp256k1_scalar *r, const unsigned char *bin); + /** Set a scalar to an unsigned integer. */ static void secp256k1_scalar_set_int(secp256k1_scalar *r, unsigned int v); @@ -103,4 +111,7 @@ static void secp256k1_scalar_split_lambda(secp256k1_scalar *r1, secp256k1_scalar /** Multiply a and b (without taking the modulus!), divide by 2**shift, and round to the nearest integer. Shift must be at least 256. */ static void secp256k1_scalar_mul_shift_var(secp256k1_scalar *r, const secp256k1_scalar *a, const secp256k1_scalar *b, unsigned int shift); +/** If flag is true, set *r equal to *a; otherwise leave it. Constant-time. Both *r and *a must be initialized.*/ +static void secp256k1_scalar_cmov(secp256k1_scalar *r, const secp256k1_scalar *a, int flag); + #endif /* SECP256K1_SCALAR_H */ diff --git a/src/secp256k1/src/scalar_4x64_impl.h b/src/secp256k1/src/scalar_4x64_impl.h index d378335d99..8f539c4bc6 100644 --- a/src/secp256k1/src/scalar_4x64_impl.h +++ b/src/secp256k1/src/scalar_4x64_impl.h @@ -946,4 +946,15 @@ SECP256K1_INLINE static void secp256k1_scalar_mul_shift_var(secp256k1_scalar *r, secp256k1_scalar_cadd_bit(r, 0, (l[(shift - 1) >> 6] >> ((shift - 1) & 0x3f)) & 1); } +static SECP256K1_INLINE void secp256k1_scalar_cmov(secp256k1_scalar *r, const secp256k1_scalar *a, int flag) { + uint64_t mask0, mask1; + VG_CHECK_VERIFY(r->d, sizeof(r->d)); + mask0 = flag + ~((uint64_t)0); + mask1 = ~mask0; + r->d[0] = (r->d[0] & mask0) | (a->d[0] & mask1); + r->d[1] = (r->d[1] & mask0) | (a->d[1] & mask1); + r->d[2] = (r->d[2] & mask0) | (a->d[2] & mask1); + r->d[3] = (r->d[3] & mask0) | (a->d[3] & mask1); +} + #endif /* SECP256K1_SCALAR_REPR_IMPL_H */ diff --git a/src/secp256k1/src/scalar_8x32_impl.h b/src/secp256k1/src/scalar_8x32_impl.h index 4f9ed61fea..3c372f34fe 100644 --- a/src/secp256k1/src/scalar_8x32_impl.h +++ b/src/secp256k1/src/scalar_8x32_impl.h @@ -718,4 +718,19 @@ SECP256K1_INLINE static void secp256k1_scalar_mul_shift_var(secp256k1_scalar *r, secp256k1_scalar_cadd_bit(r, 0, (l[(shift - 1) >> 5] >> ((shift - 1) & 0x1f)) & 1); } +static SECP256K1_INLINE void secp256k1_scalar_cmov(secp256k1_scalar *r, const secp256k1_scalar *a, int flag) { + uint32_t mask0, mask1; + VG_CHECK_VERIFY(r->d, sizeof(r->d)); + mask0 = flag + ~((uint32_t)0); + mask1 = ~mask0; + r->d[0] = (r->d[0] & mask0) | (a->d[0] & mask1); + r->d[1] = (r->d[1] & mask0) | (a->d[1] & mask1); + r->d[2] = (r->d[2] & mask0) | (a->d[2] & mask1); + r->d[3] = (r->d[3] & mask0) | (a->d[3] & mask1); + r->d[4] = (r->d[4] & mask0) | (a->d[4] & mask1); + r->d[5] = (r->d[5] & mask0) | (a->d[5] & mask1); + r->d[6] = (r->d[6] & mask0) | (a->d[6] & mask1); + r->d[7] = (r->d[7] & mask0) | (a->d[7] & mask1); +} + #endif /* SECP256K1_SCALAR_REPR_IMPL_H */ diff --git a/src/secp256k1/src/scalar_impl.h b/src/secp256k1/src/scalar_impl.h index fa790570ff..70cd73db06 100644 --- a/src/secp256k1/src/scalar_impl.h +++ b/src/secp256k1/src/scalar_impl.h @@ -7,8 +7,8 @@ #ifndef SECP256K1_SCALAR_IMPL_H #define SECP256K1_SCALAR_IMPL_H -#include "group.h" #include "scalar.h" +#include "util.h" #if defined HAVE_CONFIG_H #include "libsecp256k1-config.h" @@ -24,6 +24,9 @@ #error "Please select scalar implementation" #endif +static const secp256k1_scalar secp256k1_scalar_one = SECP256K1_SCALAR_CONST(0, 0, 0, 0, 0, 0, 0, 1); +static const secp256k1_scalar secp256k1_scalar_zero = SECP256K1_SCALAR_CONST(0, 0, 0, 0, 0, 0, 0, 0); + #ifndef USE_NUM_NONE static void secp256k1_scalar_get_num(secp256k1_num *r, const secp256k1_scalar *a) { unsigned char c[32]; @@ -52,6 +55,12 @@ static void secp256k1_scalar_order_get_num(secp256k1_num *r) { } #endif +static int secp256k1_scalar_set_b32_seckey(secp256k1_scalar *r, const unsigned char *bin) { + int overflow; + secp256k1_scalar_set_b32(r, bin, &overflow); + return (!overflow) & (!secp256k1_scalar_is_zero(r)); +} + static void secp256k1_scalar_inverse(secp256k1_scalar *r, const secp256k1_scalar *x) { #if defined(EXHAUSTIVE_TEST_ORDER) int i; diff --git a/src/secp256k1/src/scalar_low.h b/src/secp256k1/src/scalar_low.h index 5836febc5b..2794a7f171 100644 --- a/src/secp256k1/src/scalar_low.h +++ b/src/secp256k1/src/scalar_low.h @@ -12,4 +12,6 @@ /** A scalar modulo the group order of the secp256k1 curve. */ typedef uint32_t secp256k1_scalar; +#define SECP256K1_SCALAR_CONST(d7, d6, d5, d4, d3, d2, d1, d0) (d0) + #endif /* SECP256K1_SCALAR_REPR_H */ diff --git a/src/secp256k1/src/scalar_low_impl.h b/src/secp256k1/src/scalar_low_impl.h index c80e70c5a2..b79cf1ff6c 100644 --- a/src/secp256k1/src/scalar_low_impl.h +++ b/src/secp256k1/src/scalar_low_impl.h @@ -38,8 +38,11 @@ static int secp256k1_scalar_add(secp256k1_scalar *r, const secp256k1_scalar *a, static void secp256k1_scalar_cadd_bit(secp256k1_scalar *r, unsigned int bit, int flag) { if (flag && bit < 32) - *r += (1 << bit); + *r += ((uint32_t)1 << bit); #ifdef VERIFY + VERIFY_CHECK(bit < 32); + /* Verify that adding (1 << bit) will not overflow any in-range scalar *r by overflowing the underlying uint32_t. */ + VERIFY_CHECK(((uint32_t)1 << bit) - 1 <= UINT32_MAX - EXHAUSTIVE_TEST_ORDER); VERIFY_CHECK(secp256k1_scalar_check_overflow(r) == 0); #endif } @@ -111,4 +114,12 @@ SECP256K1_INLINE static int secp256k1_scalar_eq(const secp256k1_scalar *a, const return *a == *b; } +static SECP256K1_INLINE void secp256k1_scalar_cmov(secp256k1_scalar *r, const secp256k1_scalar *a, int flag) { + uint32_t mask0, mask1; + VG_CHECK_VERIFY(r, sizeof(*r)); + mask0 = flag + ~((uint32_t)0); + mask1 = ~mask0; + *r = (*r & mask0) | (*a & mask1); +} + #endif /* SECP256K1_SCALAR_REPR_IMPL_H */ diff --git a/src/secp256k1/src/scratch.h b/src/secp256k1/src/scratch.h index fef377af0d..77b35d126b 100644 --- a/src/secp256k1/src/scratch.h +++ b/src/secp256k1/src/scratch.h @@ -7,33 +7,36 @@ #ifndef _SECP256K1_SCRATCH_ #define _SECP256K1_SCRATCH_ -#define SECP256K1_SCRATCH_MAX_FRAMES 5 - /* The typedef is used internally; the struct name is used in the public API * (where it is exposed as a different typedef) */ typedef struct secp256k1_scratch_space_struct { - void *data[SECP256K1_SCRATCH_MAX_FRAMES]; - size_t offset[SECP256K1_SCRATCH_MAX_FRAMES]; - size_t frame_size[SECP256K1_SCRATCH_MAX_FRAMES]; - size_t frame; + /** guard against interpreting this object as other types */ + unsigned char magic[8]; + /** actual allocated data */ + void *data; + /** amount that has been allocated (i.e. `data + offset` is the next + * available pointer) */ + size_t alloc_size; + /** maximum size available to allocate */ size_t max_size; - const secp256k1_callback* error_callback; } secp256k1_scratch; static secp256k1_scratch* secp256k1_scratch_create(const secp256k1_callback* error_callback, size_t max_size); -static void secp256k1_scratch_destroy(secp256k1_scratch* scratch); +static void secp256k1_scratch_destroy(const secp256k1_callback* error_callback, secp256k1_scratch* scratch); -/** Attempts to allocate a new stack frame with `n` available bytes. Returns 1 on success, 0 on failure */ -static int secp256k1_scratch_allocate_frame(secp256k1_scratch* scratch, size_t n, size_t objects); +/** Returns an opaque object used to "checkpoint" a scratch space. Used + * with `secp256k1_scratch_apply_checkpoint` to undo allocations. */ +static size_t secp256k1_scratch_checkpoint(const secp256k1_callback* error_callback, const secp256k1_scratch* scratch); -/** Deallocates a stack frame */ -static void secp256k1_scratch_deallocate_frame(secp256k1_scratch* scratch); +/** Applies a check point received from `secp256k1_scratch_checkpoint`, + * undoing all allocations since that point. */ +static void secp256k1_scratch_apply_checkpoint(const secp256k1_callback* error_callback, secp256k1_scratch* scratch, size_t checkpoint); /** Returns the maximum allocation the scratch space will allow */ -static size_t secp256k1_scratch_max_allocation(const secp256k1_scratch* scratch, size_t n_objects); +static size_t secp256k1_scratch_max_allocation(const secp256k1_callback* error_callback, const secp256k1_scratch* scratch, size_t n_objects); /** Returns a pointer into the most recently allocated frame, or NULL if there is insufficient available space */ -static void *secp256k1_scratch_alloc(secp256k1_scratch* scratch, size_t n); +static void *secp256k1_scratch_alloc(const secp256k1_callback* error_callback, secp256k1_scratch* scratch, size_t n); #endif diff --git a/src/secp256k1/src/scratch_impl.h b/src/secp256k1/src/scratch_impl.h index abed713b21..4cee700001 100644 --- a/src/secp256k1/src/scratch_impl.h +++ b/src/secp256k1/src/scratch_impl.h @@ -7,78 +7,80 @@ #ifndef _SECP256K1_SCRATCH_IMPL_H_ #define _SECP256K1_SCRATCH_IMPL_H_ +#include "util.h" #include "scratch.h" -/* Using 16 bytes alignment because common architectures never have alignment - * requirements above 8 for any of the types we care about. In addition we - * leave some room because currently we don't care about a few bytes. - * TODO: Determine this at configure time. */ -#define ALIGNMENT 16 - -static secp256k1_scratch* secp256k1_scratch_create(const secp256k1_callback* error_callback, size_t max_size) { - secp256k1_scratch* ret = (secp256k1_scratch*)checked_malloc(error_callback, sizeof(*ret)); +static secp256k1_scratch* secp256k1_scratch_create(const secp256k1_callback* error_callback, size_t size) { + const size_t base_alloc = ((sizeof(secp256k1_scratch) + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; + void *alloc = checked_malloc(error_callback, base_alloc + size); + secp256k1_scratch* ret = (secp256k1_scratch *)alloc; if (ret != NULL) { memset(ret, 0, sizeof(*ret)); - ret->max_size = max_size; - ret->error_callback = error_callback; + memcpy(ret->magic, "scratch", 8); + ret->data = (void *) ((char *) alloc + base_alloc); + ret->max_size = size; } return ret; } -static void secp256k1_scratch_destroy(secp256k1_scratch* scratch) { +static void secp256k1_scratch_destroy(const secp256k1_callback* error_callback, secp256k1_scratch* scratch) { if (scratch != NULL) { - VERIFY_CHECK(scratch->frame == 0); + VERIFY_CHECK(scratch->alloc_size == 0); /* all checkpoints should be applied */ + if (memcmp(scratch->magic, "scratch", 8) != 0) { + secp256k1_callback_call(error_callback, "invalid scratch space"); + return; + } + memset(scratch->magic, 0, sizeof(scratch->magic)); free(scratch); } } -static size_t secp256k1_scratch_max_allocation(const secp256k1_scratch* scratch, size_t objects) { - size_t i = 0; - size_t allocated = 0; - for (i = 0; i < scratch->frame; i++) { - allocated += scratch->frame_size[i]; - } - if (scratch->max_size - allocated <= objects * ALIGNMENT) { +static size_t secp256k1_scratch_checkpoint(const secp256k1_callback* error_callback, const secp256k1_scratch* scratch) { + if (memcmp(scratch->magic, "scratch", 8) != 0) { + secp256k1_callback_call(error_callback, "invalid scratch space"); return 0; } - return scratch->max_size - allocated - objects * ALIGNMENT; + return scratch->alloc_size; } -static int secp256k1_scratch_allocate_frame(secp256k1_scratch* scratch, size_t n, size_t objects) { - VERIFY_CHECK(scratch->frame < SECP256K1_SCRATCH_MAX_FRAMES); - - if (n <= secp256k1_scratch_max_allocation(scratch, objects)) { - n += objects * ALIGNMENT; - scratch->data[scratch->frame] = checked_malloc(scratch->error_callback, n); - if (scratch->data[scratch->frame] == NULL) { - return 0; - } - scratch->frame_size[scratch->frame] = n; - scratch->offset[scratch->frame] = 0; - scratch->frame++; - return 1; - } else { - return 0; +static void secp256k1_scratch_apply_checkpoint(const secp256k1_callback* error_callback, secp256k1_scratch* scratch, size_t checkpoint) { + if (memcmp(scratch->magic, "scratch", 8) != 0) { + secp256k1_callback_call(error_callback, "invalid scratch space"); + return; + } + if (checkpoint > scratch->alloc_size) { + secp256k1_callback_call(error_callback, "invalid checkpoint"); + return; } + scratch->alloc_size = checkpoint; } -static void secp256k1_scratch_deallocate_frame(secp256k1_scratch* scratch) { - VERIFY_CHECK(scratch->frame > 0); - scratch->frame -= 1; - free(scratch->data[scratch->frame]); +static size_t secp256k1_scratch_max_allocation(const secp256k1_callback* error_callback, const secp256k1_scratch* scratch, size_t objects) { + if (memcmp(scratch->magic, "scratch", 8) != 0) { + secp256k1_callback_call(error_callback, "invalid scratch space"); + return 0; + } + if (scratch->max_size - scratch->alloc_size <= objects * (ALIGNMENT - 1)) { + return 0; + } + return scratch->max_size - scratch->alloc_size - objects * (ALIGNMENT - 1); } -static void *secp256k1_scratch_alloc(secp256k1_scratch* scratch, size_t size) { +static void *secp256k1_scratch_alloc(const secp256k1_callback* error_callback, secp256k1_scratch* scratch, size_t size) { void *ret; - size_t frame = scratch->frame - 1; - size = ((size + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; + size = ROUND_TO_ALIGN(size); + + if (memcmp(scratch->magic, "scratch", 8) != 0) { + secp256k1_callback_call(error_callback, "invalid scratch space"); + return NULL; + } - if (scratch->frame == 0 || size + scratch->offset[frame] > scratch->frame_size[frame]) { + if (size > scratch->max_size - scratch->alloc_size) { return NULL; } - ret = (void *) ((unsigned char *) scratch->data[frame] + scratch->offset[frame]); + ret = (void *) ((char *) scratch->data + scratch->alloc_size); memset(ret, 0, size); - scratch->offset[frame] += size; + scratch->alloc_size += size; return ret; } diff --git a/src/secp256k1/src/secp256k1.c b/src/secp256k1/src/secp256k1.c index 15981f46e2..b03a6e6345 100644 --- a/src/secp256k1/src/secp256k1.c +++ b/src/secp256k1/src/secp256k1.c @@ -5,6 +5,7 @@ **********************************************************************/ #include "include/secp256k1.h" +#include "include/secp256k1_preallocated.h" #include "util.h" #include "num_impl.h" @@ -19,6 +20,10 @@ #include "hash_impl.h" #include "scratch_impl.h" +#if defined(VALGRIND) +# include <valgrind/memcheck.h> +#endif + #define ARG_CHECK(cond) do { \ if (EXPECT(!(cond), 0)) { \ secp256k1_callback_call(&ctx->illegal_callback, #cond); \ @@ -26,53 +31,101 @@ } \ } while(0) -static void default_illegal_callback_fn(const char* str, void* data) { +#define ARG_CHECK_NO_RETURN(cond) do { \ + if (EXPECT(!(cond), 0)) { \ + secp256k1_callback_call(&ctx->illegal_callback, #cond); \ + } \ +} while(0) + +#ifndef USE_EXTERNAL_DEFAULT_CALLBACKS +#include <stdlib.h> +#include <stdio.h> +static void secp256k1_default_illegal_callback_fn(const char* str, void* data) { (void)data; fprintf(stderr, "[libsecp256k1] illegal argument: %s\n", str); abort(); } - -static const secp256k1_callback default_illegal_callback = { - default_illegal_callback_fn, - NULL -}; - -static void default_error_callback_fn(const char* str, void* data) { +static void secp256k1_default_error_callback_fn(const char* str, void* data) { (void)data; fprintf(stderr, "[libsecp256k1] internal consistency check failed: %s\n", str); abort(); } +#else +void secp256k1_default_illegal_callback_fn(const char* str, void* data); +void secp256k1_default_error_callback_fn(const char* str, void* data); +#endif -static const secp256k1_callback default_error_callback = { - default_error_callback_fn, +static const secp256k1_callback default_illegal_callback = { + secp256k1_default_illegal_callback_fn, NULL }; +static const secp256k1_callback default_error_callback = { + secp256k1_default_error_callback_fn, + NULL +}; struct secp256k1_context_struct { secp256k1_ecmult_context ecmult_ctx; secp256k1_ecmult_gen_context ecmult_gen_ctx; secp256k1_callback illegal_callback; secp256k1_callback error_callback; + int declassify; }; static const secp256k1_context secp256k1_context_no_precomp_ = { { 0 }, { 0 }, - { default_illegal_callback_fn, 0 }, - { default_error_callback_fn, 0 } + { secp256k1_default_illegal_callback_fn, 0 }, + { secp256k1_default_error_callback_fn, 0 }, + 0 }; const secp256k1_context *secp256k1_context_no_precomp = &secp256k1_context_no_precomp_; -secp256k1_context* secp256k1_context_create(unsigned int flags) { - secp256k1_context* ret = (secp256k1_context*)checked_malloc(&default_error_callback, sizeof(secp256k1_context)); +size_t secp256k1_context_preallocated_size(unsigned int flags) { + size_t ret = ROUND_TO_ALIGN(sizeof(secp256k1_context)); + + if (EXPECT((flags & SECP256K1_FLAGS_TYPE_MASK) != SECP256K1_FLAGS_TYPE_CONTEXT, 0)) { + secp256k1_callback_call(&default_illegal_callback, + "Invalid flags"); + return 0; + } + + if (flags & SECP256K1_FLAGS_BIT_CONTEXT_SIGN) { + ret += SECP256K1_ECMULT_GEN_CONTEXT_PREALLOCATED_SIZE; + } + if (flags & SECP256K1_FLAGS_BIT_CONTEXT_VERIFY) { + ret += SECP256K1_ECMULT_CONTEXT_PREALLOCATED_SIZE; + } + return ret; +} + +size_t secp256k1_context_preallocated_clone_size(const secp256k1_context* ctx) { + size_t ret = ROUND_TO_ALIGN(sizeof(secp256k1_context)); + VERIFY_CHECK(ctx != NULL); + if (secp256k1_ecmult_gen_context_is_built(&ctx->ecmult_gen_ctx)) { + ret += SECP256K1_ECMULT_GEN_CONTEXT_PREALLOCATED_SIZE; + } + if (secp256k1_ecmult_context_is_built(&ctx->ecmult_ctx)) { + ret += SECP256K1_ECMULT_CONTEXT_PREALLOCATED_SIZE; + } + return ret; +} + +secp256k1_context* secp256k1_context_preallocated_create(void* prealloc, unsigned int flags) { + void* const base = prealloc; + size_t prealloc_size; + secp256k1_context* ret; + + VERIFY_CHECK(prealloc != NULL); + prealloc_size = secp256k1_context_preallocated_size(flags); + ret = (secp256k1_context*)manual_alloc(&prealloc, sizeof(secp256k1_context), base, prealloc_size); ret->illegal_callback = default_illegal_callback; ret->error_callback = default_error_callback; if (EXPECT((flags & SECP256K1_FLAGS_TYPE_MASK) != SECP256K1_FLAGS_TYPE_CONTEXT, 0)) { secp256k1_callback_call(&ret->illegal_callback, "Invalid flags"); - free(ret); return NULL; } @@ -80,47 +133,80 @@ secp256k1_context* secp256k1_context_create(unsigned int flags) { secp256k1_ecmult_gen_context_init(&ret->ecmult_gen_ctx); if (flags & SECP256K1_FLAGS_BIT_CONTEXT_SIGN) { - secp256k1_ecmult_gen_context_build(&ret->ecmult_gen_ctx, &ret->error_callback); + secp256k1_ecmult_gen_context_build(&ret->ecmult_gen_ctx, &prealloc); } if (flags & SECP256K1_FLAGS_BIT_CONTEXT_VERIFY) { - secp256k1_ecmult_context_build(&ret->ecmult_ctx, &ret->error_callback); + secp256k1_ecmult_context_build(&ret->ecmult_ctx, &prealloc); + } + ret->declassify = !!(flags & SECP256K1_FLAGS_BIT_CONTEXT_DECLASSIFY); + + return (secp256k1_context*) ret; +} + +secp256k1_context* secp256k1_context_create(unsigned int flags) { + size_t const prealloc_size = secp256k1_context_preallocated_size(flags); + secp256k1_context* ctx = (secp256k1_context*)checked_malloc(&default_error_callback, prealloc_size); + if (EXPECT(secp256k1_context_preallocated_create(ctx, flags) == NULL, 0)) { + free(ctx); + return NULL; } + return ctx; +} + +secp256k1_context* secp256k1_context_preallocated_clone(const secp256k1_context* ctx, void* prealloc) { + size_t prealloc_size; + secp256k1_context* ret; + VERIFY_CHECK(ctx != NULL); + ARG_CHECK(prealloc != NULL); + + prealloc_size = secp256k1_context_preallocated_clone_size(ctx); + ret = (secp256k1_context*)prealloc; + memcpy(ret, ctx, prealloc_size); + secp256k1_ecmult_gen_context_finalize_memcpy(&ret->ecmult_gen_ctx, &ctx->ecmult_gen_ctx); + secp256k1_ecmult_context_finalize_memcpy(&ret->ecmult_ctx, &ctx->ecmult_ctx); return ret; } secp256k1_context* secp256k1_context_clone(const secp256k1_context* ctx) { - secp256k1_context* ret = (secp256k1_context*)checked_malloc(&ctx->error_callback, sizeof(secp256k1_context)); - ret->illegal_callback = ctx->illegal_callback; - ret->error_callback = ctx->error_callback; - secp256k1_ecmult_context_clone(&ret->ecmult_ctx, &ctx->ecmult_ctx, &ctx->error_callback); - secp256k1_ecmult_gen_context_clone(&ret->ecmult_gen_ctx, &ctx->ecmult_gen_ctx, &ctx->error_callback); + secp256k1_context* ret; + size_t prealloc_size; + + VERIFY_CHECK(ctx != NULL); + prealloc_size = secp256k1_context_preallocated_clone_size(ctx); + ret = (secp256k1_context*)checked_malloc(&ctx->error_callback, prealloc_size); + ret = secp256k1_context_preallocated_clone(ctx, ret); return ret; } -void secp256k1_context_destroy(secp256k1_context* ctx) { - CHECK(ctx != secp256k1_context_no_precomp); +void secp256k1_context_preallocated_destroy(secp256k1_context* ctx) { + ARG_CHECK_NO_RETURN(ctx != secp256k1_context_no_precomp); if (ctx != NULL) { secp256k1_ecmult_context_clear(&ctx->ecmult_ctx); secp256k1_ecmult_gen_context_clear(&ctx->ecmult_gen_ctx); + } +} +void secp256k1_context_destroy(secp256k1_context* ctx) { + if (ctx != NULL) { + secp256k1_context_preallocated_destroy(ctx); free(ctx); } } void secp256k1_context_set_illegal_callback(secp256k1_context* ctx, void (*fun)(const char* message, void* data), const void* data) { - CHECK(ctx != secp256k1_context_no_precomp); + ARG_CHECK_NO_RETURN(ctx != secp256k1_context_no_precomp); if (fun == NULL) { - fun = default_illegal_callback_fn; + fun = secp256k1_default_illegal_callback_fn; } ctx->illegal_callback.fn = fun; ctx->illegal_callback.data = data; } void secp256k1_context_set_error_callback(secp256k1_context* ctx, void (*fun)(const char* message, void* data), const void* data) { - CHECK(ctx != secp256k1_context_no_precomp); + ARG_CHECK_NO_RETURN(ctx != secp256k1_context_no_precomp); if (fun == NULL) { - fun = default_error_callback_fn; + fun = secp256k1_default_error_callback_fn; } ctx->error_callback.fn = fun; ctx->error_callback.data = data; @@ -131,8 +217,23 @@ secp256k1_scratch_space* secp256k1_scratch_space_create(const secp256k1_context* return secp256k1_scratch_create(&ctx->error_callback, max_size); } -void secp256k1_scratch_space_destroy(secp256k1_scratch_space* scratch) { - secp256k1_scratch_destroy(scratch); +void secp256k1_scratch_space_destroy(const secp256k1_context *ctx, secp256k1_scratch_space* scratch) { + VERIFY_CHECK(ctx != NULL); + secp256k1_scratch_destroy(&ctx->error_callback, scratch); +} + +/* Mark memory as no-longer-secret for the purpose of analysing constant-time behaviour + * of the software. This is setup for use with valgrind but could be substituted with + * the appropriate instrumentation for other analysis tools. + */ +static SECP256K1_INLINE void secp256k1_declassify(const secp256k1_context* ctx, void *p, size_t len) { +#if defined(VALGRIND) + if (EXPECT(ctx->declassify,0)) VALGRIND_MAKE_MEM_DEFINED(p, len); +#else + (void)ctx; + (void)p; + (void)len; +#endif } static int secp256k1_pubkey_load(const secp256k1_context* ctx, secp256k1_ge* ge, const secp256k1_pubkey* pubkey) { @@ -366,61 +467,83 @@ static int nonce_function_rfc6979(unsigned char *nonce32, const unsigned char *m const secp256k1_nonce_function secp256k1_nonce_function_rfc6979 = nonce_function_rfc6979; const secp256k1_nonce_function secp256k1_nonce_function_default = nonce_function_rfc6979; -int secp256k1_ecdsa_sign(const secp256k1_context* ctx, secp256k1_ecdsa_signature *signature, const unsigned char *msg32, const unsigned char *seckey, secp256k1_nonce_function noncefp, const void* noncedata) { - secp256k1_scalar r, s; +static int secp256k1_ecdsa_sign_inner(const secp256k1_context* ctx, secp256k1_scalar* r, secp256k1_scalar* s, int* recid, const unsigned char *msg32, const unsigned char *seckey, secp256k1_nonce_function noncefp, const void* noncedata) { secp256k1_scalar sec, non, msg; int ret = 0; - int overflow = 0; - VERIFY_CHECK(ctx != NULL); - ARG_CHECK(secp256k1_ecmult_gen_context_is_built(&ctx->ecmult_gen_ctx)); - ARG_CHECK(msg32 != NULL); - ARG_CHECK(signature != NULL); - ARG_CHECK(seckey != NULL); + int is_sec_valid; + unsigned char nonce32[32]; + unsigned int count = 0; + /* Default initialization here is important so we won't pass uninit values to the cmov in the end */ + *r = secp256k1_scalar_zero; + *s = secp256k1_scalar_zero; + if (recid) { + *recid = 0; + } if (noncefp == NULL) { noncefp = secp256k1_nonce_function_default; } - secp256k1_scalar_set_b32(&sec, seckey, &overflow); /* Fail if the secret key is invalid. */ - if (!overflow && !secp256k1_scalar_is_zero(&sec)) { - unsigned char nonce32[32]; - unsigned int count = 0; - secp256k1_scalar_set_b32(&msg, msg32, NULL); - while (1) { - ret = noncefp(nonce32, msg32, seckey, NULL, (void*)noncedata, count); - if (!ret) { + is_sec_valid = secp256k1_scalar_set_b32_seckey(&sec, seckey); + secp256k1_scalar_cmov(&sec, &secp256k1_scalar_one, !is_sec_valid); + secp256k1_scalar_set_b32(&msg, msg32, NULL); + while (1) { + int is_nonce_valid; + ret = !!noncefp(nonce32, msg32, seckey, NULL, (void*)noncedata, count); + if (!ret) { + break; + } + is_nonce_valid = secp256k1_scalar_set_b32_seckey(&non, nonce32); + /* The nonce is still secret here, but it being invalid is is less likely than 1:2^255. */ + secp256k1_declassify(ctx, &is_nonce_valid, sizeof(is_nonce_valid)); + if (is_nonce_valid) { + ret = secp256k1_ecdsa_sig_sign(&ctx->ecmult_gen_ctx, r, s, &sec, &msg, &non, recid); + /* The final signature is no longer a secret, nor is the fact that we were successful or not. */ + secp256k1_declassify(ctx, &ret, sizeof(ret)); + if (ret) { break; } - secp256k1_scalar_set_b32(&non, nonce32, &overflow); - if (!overflow && !secp256k1_scalar_is_zero(&non)) { - if (secp256k1_ecdsa_sig_sign(&ctx->ecmult_gen_ctx, &r, &s, &sec, &msg, &non, NULL)) { - break; - } - } - count++; } - memset(nonce32, 0, 32); - secp256k1_scalar_clear(&msg); - secp256k1_scalar_clear(&non); - secp256k1_scalar_clear(&sec); + count++; } - if (ret) { - secp256k1_ecdsa_signature_save(signature, &r, &s); - } else { - memset(signature, 0, sizeof(*signature)); + /* We don't want to declassify is_sec_valid and therefore the range of + * seckey. As a result is_sec_valid is included in ret only after ret was + * used as a branching variable. */ + ret &= is_sec_valid; + memset(nonce32, 0, 32); + secp256k1_scalar_clear(&msg); + secp256k1_scalar_clear(&non); + secp256k1_scalar_clear(&sec); + secp256k1_scalar_cmov(r, &secp256k1_scalar_zero, !ret); + secp256k1_scalar_cmov(s, &secp256k1_scalar_zero, !ret); + if (recid) { + const int zero = 0; + secp256k1_int_cmov(recid, &zero, !ret); } return ret; } +int secp256k1_ecdsa_sign(const secp256k1_context* ctx, secp256k1_ecdsa_signature *signature, const unsigned char *msg32, const unsigned char *seckey, secp256k1_nonce_function noncefp, const void* noncedata) { + secp256k1_scalar r, s; + int ret; + VERIFY_CHECK(ctx != NULL); + ARG_CHECK(secp256k1_ecmult_gen_context_is_built(&ctx->ecmult_gen_ctx)); + ARG_CHECK(msg32 != NULL); + ARG_CHECK(signature != NULL); + ARG_CHECK(seckey != NULL); + + ret = secp256k1_ecdsa_sign_inner(ctx, &r, &s, NULL, msg32, seckey, noncefp, noncedata); + secp256k1_ecdsa_signature_save(signature, &r, &s); + return ret; +} + int secp256k1_ec_seckey_verify(const secp256k1_context* ctx, const unsigned char *seckey) { secp256k1_scalar sec; int ret; - int overflow; VERIFY_CHECK(ctx != NULL); ARG_CHECK(seckey != NULL); - secp256k1_scalar_set_b32(&sec, seckey, &overflow); - ret = !overflow && !secp256k1_scalar_is_zero(&sec); + ret = secp256k1_scalar_set_b32_seckey(&sec, seckey); secp256k1_scalar_clear(&sec); return ret; } @@ -429,7 +552,6 @@ int secp256k1_ec_pubkey_create(const secp256k1_context* ctx, secp256k1_pubkey *p secp256k1_gej pj; secp256k1_ge p; secp256k1_scalar sec; - int overflow; int ret = 0; VERIFY_CHECK(ctx != NULL); ARG_CHECK(pubkey != NULL); @@ -437,27 +559,35 @@ int secp256k1_ec_pubkey_create(const secp256k1_context* ctx, secp256k1_pubkey *p ARG_CHECK(secp256k1_ecmult_gen_context_is_built(&ctx->ecmult_gen_ctx)); ARG_CHECK(seckey != NULL); - secp256k1_scalar_set_b32(&sec, seckey, &overflow); - ret = (!overflow) & (!secp256k1_scalar_is_zero(&sec)); - if (ret) { - secp256k1_ecmult_gen(&ctx->ecmult_gen_ctx, &pj, &sec); - secp256k1_ge_set_gej(&p, &pj); - secp256k1_pubkey_save(pubkey, &p); - } + ret = secp256k1_scalar_set_b32_seckey(&sec, seckey); + secp256k1_scalar_cmov(&sec, &secp256k1_scalar_one, !ret); + + secp256k1_ecmult_gen(&ctx->ecmult_gen_ctx, &pj, &sec); + secp256k1_ge_set_gej(&p, &pj); + secp256k1_pubkey_save(pubkey, &p); + memczero(pubkey, sizeof(*pubkey), !ret); + secp256k1_scalar_clear(&sec); return ret; } -int secp256k1_ec_privkey_negate(const secp256k1_context* ctx, unsigned char *seckey) { +int secp256k1_ec_seckey_negate(const secp256k1_context* ctx, unsigned char *seckey) { secp256k1_scalar sec; + int ret = 0; VERIFY_CHECK(ctx != NULL); ARG_CHECK(seckey != NULL); - secp256k1_scalar_set_b32(&sec, seckey, NULL); + ret = secp256k1_scalar_set_b32_seckey(&sec, seckey); + secp256k1_scalar_cmov(&sec, &secp256k1_scalar_zero, !ret); secp256k1_scalar_negate(&sec, &sec); secp256k1_scalar_get_b32(seckey, &sec); - return 1; + secp256k1_scalar_clear(&sec); + return ret; +} + +int secp256k1_ec_privkey_negate(const secp256k1_context* ctx, unsigned char *seckey) { + return secp256k1_ec_seckey_negate(ctx, seckey); } int secp256k1_ec_pubkey_negate(const secp256k1_context* ctx, secp256k1_pubkey *pubkey) { @@ -475,7 +605,7 @@ int secp256k1_ec_pubkey_negate(const secp256k1_context* ctx, secp256k1_pubkey *p return ret; } -int secp256k1_ec_privkey_tweak_add(const secp256k1_context* ctx, unsigned char *seckey, const unsigned char *tweak) { +int secp256k1_ec_seckey_tweak_add(const secp256k1_context* ctx, unsigned char *seckey, const unsigned char *tweak) { secp256k1_scalar term; secp256k1_scalar sec; int ret = 0; @@ -485,19 +615,21 @@ int secp256k1_ec_privkey_tweak_add(const secp256k1_context* ctx, unsigned char * ARG_CHECK(tweak != NULL); secp256k1_scalar_set_b32(&term, tweak, &overflow); - secp256k1_scalar_set_b32(&sec, seckey, NULL); + ret = secp256k1_scalar_set_b32_seckey(&sec, seckey); - ret = !overflow && secp256k1_eckey_privkey_tweak_add(&sec, &term); - memset(seckey, 0, 32); - if (ret) { - secp256k1_scalar_get_b32(seckey, &sec); - } + ret &= (!overflow) & secp256k1_eckey_privkey_tweak_add(&sec, &term); + secp256k1_scalar_cmov(&sec, &secp256k1_scalar_zero, !ret); + secp256k1_scalar_get_b32(seckey, &sec); secp256k1_scalar_clear(&sec); secp256k1_scalar_clear(&term); return ret; } +int secp256k1_ec_privkey_tweak_add(const secp256k1_context* ctx, unsigned char *seckey, const unsigned char *tweak) { + return secp256k1_ec_seckey_tweak_add(ctx, seckey, tweak); +} + int secp256k1_ec_pubkey_tweak_add(const secp256k1_context* ctx, secp256k1_pubkey *pubkey, const unsigned char *tweak) { secp256k1_ge p; secp256k1_scalar term; @@ -522,7 +654,7 @@ int secp256k1_ec_pubkey_tweak_add(const secp256k1_context* ctx, secp256k1_pubkey return ret; } -int secp256k1_ec_privkey_tweak_mul(const secp256k1_context* ctx, unsigned char *seckey, const unsigned char *tweak) { +int secp256k1_ec_seckey_tweak_mul(const secp256k1_context* ctx, unsigned char *seckey, const unsigned char *tweak) { secp256k1_scalar factor; secp256k1_scalar sec; int ret = 0; @@ -532,18 +664,20 @@ int secp256k1_ec_privkey_tweak_mul(const secp256k1_context* ctx, unsigned char * ARG_CHECK(tweak != NULL); secp256k1_scalar_set_b32(&factor, tweak, &overflow); - secp256k1_scalar_set_b32(&sec, seckey, NULL); - ret = !overflow && secp256k1_eckey_privkey_tweak_mul(&sec, &factor); - memset(seckey, 0, 32); - if (ret) { - secp256k1_scalar_get_b32(seckey, &sec); - } + ret = secp256k1_scalar_set_b32_seckey(&sec, seckey); + ret &= (!overflow) & secp256k1_eckey_privkey_tweak_mul(&sec, &factor); + secp256k1_scalar_cmov(&sec, &secp256k1_scalar_zero, !ret); + secp256k1_scalar_get_b32(seckey, &sec); secp256k1_scalar_clear(&sec); secp256k1_scalar_clear(&factor); return ret; } +int secp256k1_ec_privkey_tweak_mul(const secp256k1_context* ctx, unsigned char *seckey, const unsigned char *tweak) { + return secp256k1_ec_seckey_tweak_mul(ctx, seckey, tweak); +} + int secp256k1_ec_pubkey_tweak_mul(const secp256k1_context* ctx, secp256k1_pubkey *pubkey, const unsigned char *tweak) { secp256k1_ge p; secp256k1_scalar factor; diff --git a/src/secp256k1/src/tests.c b/src/secp256k1/src/tests.c index f1c4db929a..374ed7dc12 100644 --- a/src/secp256k1/src/tests.c +++ b/src/secp256k1/src/tests.c @@ -16,6 +16,7 @@ #include "secp256k1.c" #include "include/secp256k1.h" +#include "include/secp256k1_preallocated.h" #include "testrand_impl.h" #ifdef ENABLE_OPENSSL_TESTS @@ -31,17 +32,6 @@ void ECDSA_SIG_get0(const ECDSA_SIG *sig, const BIGNUM **pr, const BIGNUM **ps) #include "contrib/lax_der_parsing.c" #include "contrib/lax_der_privatekey_parsing.c" -#if !defined(VG_CHECK) -# if defined(VALGRIND) -# include <valgrind/memcheck.h> -# define VG_UNDEF(x,y) VALGRIND_MAKE_MEM_UNDEFINED((x),(y)) -# define VG_CHECK(x,y) VALGRIND_CHECK_MEM_IS_DEFINED((x),(y)) -# else -# define VG_UNDEF(x,y) -# define VG_CHECK(x,y) -# endif -#endif - static int count = 64; static secp256k1_context *ctx = NULL; @@ -82,7 +72,9 @@ void random_field_element_magnitude(secp256k1_fe *fe) { secp256k1_fe_negate(&zero, &zero, 0); secp256k1_fe_mul_int(&zero, n - 1); secp256k1_fe_add(fe, &zero); - VERIFY_CHECK(fe->magnitude == n); +#ifdef VERIFY + CHECK(fe->magnitude == n); +#endif } void random_group_element_test(secp256k1_ge *ge) { @@ -137,23 +129,53 @@ void random_scalar_order(secp256k1_scalar *num) { } while(1); } -void run_context_tests(void) { +void random_scalar_order_b32(unsigned char *b32) { + secp256k1_scalar num; + random_scalar_order(&num); + secp256k1_scalar_get_b32(b32, &num); +} + +void run_context_tests(int use_prealloc) { secp256k1_pubkey pubkey; secp256k1_pubkey zero_pubkey; secp256k1_ecdsa_signature sig; unsigned char ctmp[32]; int32_t ecount; int32_t ecount2; - secp256k1_context *none = secp256k1_context_create(SECP256K1_CONTEXT_NONE); - secp256k1_context *sign = secp256k1_context_create(SECP256K1_CONTEXT_SIGN); - secp256k1_context *vrfy = secp256k1_context_create(SECP256K1_CONTEXT_VERIFY); - secp256k1_context *both = secp256k1_context_create(SECP256K1_CONTEXT_SIGN | SECP256K1_CONTEXT_VERIFY); + secp256k1_context *none; + secp256k1_context *sign; + secp256k1_context *vrfy; + secp256k1_context *both; + void *none_prealloc = NULL; + void *sign_prealloc = NULL; + void *vrfy_prealloc = NULL; + void *both_prealloc = NULL; secp256k1_gej pubj; secp256k1_ge pub; secp256k1_scalar msg, key, nonce; secp256k1_scalar sigr, sigs; + if (use_prealloc) { + none_prealloc = malloc(secp256k1_context_preallocated_size(SECP256K1_CONTEXT_NONE)); + sign_prealloc = malloc(secp256k1_context_preallocated_size(SECP256K1_CONTEXT_SIGN)); + vrfy_prealloc = malloc(secp256k1_context_preallocated_size(SECP256K1_CONTEXT_VERIFY)); + both_prealloc = malloc(secp256k1_context_preallocated_size(SECP256K1_CONTEXT_SIGN | SECP256K1_CONTEXT_VERIFY)); + CHECK(none_prealloc != NULL); + CHECK(sign_prealloc != NULL); + CHECK(vrfy_prealloc != NULL); + CHECK(both_prealloc != NULL); + none = secp256k1_context_preallocated_create(none_prealloc, SECP256K1_CONTEXT_NONE); + sign = secp256k1_context_preallocated_create(sign_prealloc, SECP256K1_CONTEXT_SIGN); + vrfy = secp256k1_context_preallocated_create(vrfy_prealloc, SECP256K1_CONTEXT_VERIFY); + both = secp256k1_context_preallocated_create(both_prealloc, SECP256K1_CONTEXT_SIGN | SECP256K1_CONTEXT_VERIFY); + } else { + none = secp256k1_context_create(SECP256K1_CONTEXT_NONE); + sign = secp256k1_context_create(SECP256K1_CONTEXT_SIGN); + vrfy = secp256k1_context_create(SECP256K1_CONTEXT_VERIFY); + both = secp256k1_context_create(SECP256K1_CONTEXT_SIGN | SECP256K1_CONTEXT_VERIFY); + } + memset(&zero_pubkey, 0, sizeof(zero_pubkey)); ecount = 0; @@ -163,14 +185,57 @@ void run_context_tests(void) { secp256k1_context_set_error_callback(sign, counting_illegal_callback_fn, NULL); CHECK(vrfy->error_callback.fn != sign->error_callback.fn); + /* check if sizes for cloning are consistent */ + CHECK(secp256k1_context_preallocated_clone_size(none) == secp256k1_context_preallocated_size(SECP256K1_CONTEXT_NONE)); + CHECK(secp256k1_context_preallocated_clone_size(sign) == secp256k1_context_preallocated_size(SECP256K1_CONTEXT_SIGN)); + CHECK(secp256k1_context_preallocated_clone_size(vrfy) == secp256k1_context_preallocated_size(SECP256K1_CONTEXT_VERIFY)); + CHECK(secp256k1_context_preallocated_clone_size(both) == secp256k1_context_preallocated_size(SECP256K1_CONTEXT_SIGN | SECP256K1_CONTEXT_VERIFY)); + /*** clone and destroy all of them to make sure cloning was complete ***/ { secp256k1_context *ctx_tmp; - ctx_tmp = none; none = secp256k1_context_clone(none); secp256k1_context_destroy(ctx_tmp); - ctx_tmp = sign; sign = secp256k1_context_clone(sign); secp256k1_context_destroy(ctx_tmp); - ctx_tmp = vrfy; vrfy = secp256k1_context_clone(vrfy); secp256k1_context_destroy(ctx_tmp); - ctx_tmp = both; both = secp256k1_context_clone(both); secp256k1_context_destroy(ctx_tmp); + if (use_prealloc) { + /* clone into a non-preallocated context and then again into a new preallocated one. */ + ctx_tmp = none; none = secp256k1_context_clone(none); secp256k1_context_preallocated_destroy(ctx_tmp); + free(none_prealloc); none_prealloc = malloc(secp256k1_context_preallocated_size(SECP256K1_CONTEXT_NONE)); CHECK(none_prealloc != NULL); + ctx_tmp = none; none = secp256k1_context_preallocated_clone(none, none_prealloc); secp256k1_context_destroy(ctx_tmp); + + ctx_tmp = sign; sign = secp256k1_context_clone(sign); secp256k1_context_preallocated_destroy(ctx_tmp); + free(sign_prealloc); sign_prealloc = malloc(secp256k1_context_preallocated_size(SECP256K1_CONTEXT_SIGN)); CHECK(sign_prealloc != NULL); + ctx_tmp = sign; sign = secp256k1_context_preallocated_clone(sign, sign_prealloc); secp256k1_context_destroy(ctx_tmp); + + ctx_tmp = vrfy; vrfy = secp256k1_context_clone(vrfy); secp256k1_context_preallocated_destroy(ctx_tmp); + free(vrfy_prealloc); vrfy_prealloc = malloc(secp256k1_context_preallocated_size(SECP256K1_CONTEXT_VERIFY)); CHECK(vrfy_prealloc != NULL); + ctx_tmp = vrfy; vrfy = secp256k1_context_preallocated_clone(vrfy, vrfy_prealloc); secp256k1_context_destroy(ctx_tmp); + + ctx_tmp = both; both = secp256k1_context_clone(both); secp256k1_context_preallocated_destroy(ctx_tmp); + free(both_prealloc); both_prealloc = malloc(secp256k1_context_preallocated_size(SECP256K1_CONTEXT_SIGN | SECP256K1_CONTEXT_VERIFY)); CHECK(both_prealloc != NULL); + ctx_tmp = both; both = secp256k1_context_preallocated_clone(both, both_prealloc); secp256k1_context_destroy(ctx_tmp); + } else { + /* clone into a preallocated context and then again into a new non-preallocated one. */ + void *prealloc_tmp; + + prealloc_tmp = malloc(secp256k1_context_preallocated_size(SECP256K1_CONTEXT_NONE)); CHECK(prealloc_tmp != NULL); + ctx_tmp = none; none = secp256k1_context_preallocated_clone(none, prealloc_tmp); secp256k1_context_destroy(ctx_tmp); + ctx_tmp = none; none = secp256k1_context_clone(none); secp256k1_context_preallocated_destroy(ctx_tmp); + free(prealloc_tmp); + + prealloc_tmp = malloc(secp256k1_context_preallocated_size(SECP256K1_CONTEXT_SIGN)); CHECK(prealloc_tmp != NULL); + ctx_tmp = sign; sign = secp256k1_context_preallocated_clone(sign, prealloc_tmp); secp256k1_context_destroy(ctx_tmp); + ctx_tmp = sign; sign = secp256k1_context_clone(sign); secp256k1_context_preallocated_destroy(ctx_tmp); + free(prealloc_tmp); + + prealloc_tmp = malloc(secp256k1_context_preallocated_size(SECP256K1_CONTEXT_VERIFY)); CHECK(prealloc_tmp != NULL); + ctx_tmp = vrfy; vrfy = secp256k1_context_preallocated_clone(vrfy, prealloc_tmp); secp256k1_context_destroy(ctx_tmp); + ctx_tmp = vrfy; vrfy = secp256k1_context_clone(vrfy); secp256k1_context_preallocated_destroy(ctx_tmp); + free(prealloc_tmp); + + prealloc_tmp = malloc(secp256k1_context_preallocated_size(SECP256K1_CONTEXT_SIGN | SECP256K1_CONTEXT_VERIFY)); CHECK(prealloc_tmp != NULL); + ctx_tmp = both; both = secp256k1_context_preallocated_clone(both, prealloc_tmp); secp256k1_context_destroy(ctx_tmp); + ctx_tmp = both; both = secp256k1_context_clone(both); secp256k1_context_preallocated_destroy(ctx_tmp); + free(prealloc_tmp); + } } /* Verify that the error callback makes it across the clone. */ @@ -229,10 +294,6 @@ void run_context_tests(void) { secp256k1_context_set_illegal_callback(vrfy, NULL, NULL); secp256k1_context_set_illegal_callback(sign, NULL, NULL); - /* This shouldn't leak memory, due to already-set tests. */ - secp256k1_ecmult_gen_context_build(&sign->ecmult_gen_ctx, NULL); - secp256k1_ecmult_context_build(&vrfy->ecmult_ctx, NULL); - /* obtain a working nonce */ do { random_scalar_order_test(&nonce); @@ -247,49 +308,95 @@ void run_context_tests(void) { CHECK(secp256k1_ecdsa_sig_verify(&both->ecmult_ctx, &sigr, &sigs, &pub, &msg)); /* cleanup */ - secp256k1_context_destroy(none); - secp256k1_context_destroy(sign); - secp256k1_context_destroy(vrfy); - secp256k1_context_destroy(both); + if (use_prealloc) { + secp256k1_context_preallocated_destroy(none); + secp256k1_context_preallocated_destroy(sign); + secp256k1_context_preallocated_destroy(vrfy); + secp256k1_context_preallocated_destroy(both); + free(none_prealloc); + free(sign_prealloc); + free(vrfy_prealloc); + free(both_prealloc); + } else { + secp256k1_context_destroy(none); + secp256k1_context_destroy(sign); + secp256k1_context_destroy(vrfy); + secp256k1_context_destroy(both); + } /* Defined as no-op. */ secp256k1_context_destroy(NULL); + secp256k1_context_preallocated_destroy(NULL); + } void run_scratch_tests(void) { + const size_t adj_alloc = ((500 + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT; + int32_t ecount = 0; + size_t checkpoint; + size_t checkpoint_2; secp256k1_context *none = secp256k1_context_create(SECP256K1_CONTEXT_NONE); secp256k1_scratch_space *scratch; + secp256k1_scratch_space local_scratch; /* Test public API */ secp256k1_context_set_illegal_callback(none, counting_illegal_callback_fn, &ecount); + secp256k1_context_set_error_callback(none, counting_illegal_callback_fn, &ecount); scratch = secp256k1_scratch_space_create(none, 1000); CHECK(scratch != NULL); CHECK(ecount == 0); /* Test internal API */ - CHECK(secp256k1_scratch_max_allocation(scratch, 0) == 1000); - CHECK(secp256k1_scratch_max_allocation(scratch, 1) < 1000); - - /* Allocating 500 bytes with no frame fails */ - CHECK(secp256k1_scratch_alloc(scratch, 500) == NULL); - CHECK(secp256k1_scratch_max_allocation(scratch, 0) == 1000); - - /* ...but pushing a new stack frame does affect the max allocation */ - CHECK(secp256k1_scratch_allocate_frame(scratch, 500, 1 == 1)); - CHECK(secp256k1_scratch_max_allocation(scratch, 1) < 500); /* 500 - ALIGNMENT */ - CHECK(secp256k1_scratch_alloc(scratch, 500) != NULL); - CHECK(secp256k1_scratch_alloc(scratch, 500) == NULL); - - CHECK(secp256k1_scratch_allocate_frame(scratch, 500, 1) == 0); + CHECK(secp256k1_scratch_max_allocation(&none->error_callback, scratch, 0) == 1000); + CHECK(secp256k1_scratch_max_allocation(&none->error_callback, scratch, 1) == 1000 - (ALIGNMENT - 1)); + CHECK(scratch->alloc_size == 0); + CHECK(scratch->alloc_size % ALIGNMENT == 0); + + /* Allocating 500 bytes succeeds */ + checkpoint = secp256k1_scratch_checkpoint(&none->error_callback, scratch); + CHECK(secp256k1_scratch_alloc(&none->error_callback, scratch, 500) != NULL); + CHECK(secp256k1_scratch_max_allocation(&none->error_callback, scratch, 0) == 1000 - adj_alloc); + CHECK(secp256k1_scratch_max_allocation(&none->error_callback, scratch, 1) == 1000 - adj_alloc - (ALIGNMENT - 1)); + CHECK(scratch->alloc_size != 0); + CHECK(scratch->alloc_size % ALIGNMENT == 0); + + /* Allocating another 500 bytes fails */ + CHECK(secp256k1_scratch_alloc(&none->error_callback, scratch, 500) == NULL); + CHECK(secp256k1_scratch_max_allocation(&none->error_callback, scratch, 0) == 1000 - adj_alloc); + CHECK(secp256k1_scratch_max_allocation(&none->error_callback, scratch, 1) == 1000 - adj_alloc - (ALIGNMENT - 1)); + CHECK(scratch->alloc_size != 0); + CHECK(scratch->alloc_size % ALIGNMENT == 0); + + /* ...but it succeeds once we apply the checkpoint to undo it */ + secp256k1_scratch_apply_checkpoint(&none->error_callback, scratch, checkpoint); + CHECK(scratch->alloc_size == 0); + CHECK(secp256k1_scratch_max_allocation(&none->error_callback, scratch, 0) == 1000); + CHECK(secp256k1_scratch_alloc(&none->error_callback, scratch, 500) != NULL); + CHECK(scratch->alloc_size != 0); + + /* try to apply a bad checkpoint */ + checkpoint_2 = secp256k1_scratch_checkpoint(&none->error_callback, scratch); + secp256k1_scratch_apply_checkpoint(&none->error_callback, scratch, checkpoint); + CHECK(ecount == 0); + secp256k1_scratch_apply_checkpoint(&none->error_callback, scratch, checkpoint_2); /* checkpoint_2 is after checkpoint */ + CHECK(ecount == 1); + secp256k1_scratch_apply_checkpoint(&none->error_callback, scratch, (size_t) -1); /* this is just wildly invalid */ + CHECK(ecount == 2); - /* ...and this effect is undone by popping the frame */ - secp256k1_scratch_deallocate_frame(scratch); - CHECK(secp256k1_scratch_max_allocation(scratch, 0) == 1000); - CHECK(secp256k1_scratch_alloc(scratch, 500) == NULL); + /* try to use badly initialized scratch space */ + secp256k1_scratch_space_destroy(none, scratch); + memset(&local_scratch, 0, sizeof(local_scratch)); + scratch = &local_scratch; + CHECK(!secp256k1_scratch_max_allocation(&none->error_callback, scratch, 0)); + CHECK(ecount == 3); + CHECK(secp256k1_scratch_alloc(&none->error_callback, scratch, 500) == NULL); + CHECK(ecount == 4); + secp256k1_scratch_space_destroy(none, scratch); + CHECK(ecount == 5); /* cleanup */ - secp256k1_scratch_space_destroy(scratch); + secp256k1_scratch_space_destroy(none, NULL); /* no-op */ secp256k1_context_destroy(none); } @@ -965,11 +1072,31 @@ void scalar_test(void) { } +void run_scalar_set_b32_seckey_tests(void) { + unsigned char b32[32]; + secp256k1_scalar s1; + secp256k1_scalar s2; + + /* Usually set_b32 and set_b32_seckey give the same result */ + random_scalar_order_b32(b32); + secp256k1_scalar_set_b32(&s1, b32, NULL); + CHECK(secp256k1_scalar_set_b32_seckey(&s2, b32) == 1); + CHECK(secp256k1_scalar_eq(&s1, &s2) == 1); + + memset(b32, 0, sizeof(b32)); + CHECK(secp256k1_scalar_set_b32_seckey(&s2, b32) == 0); + memset(b32, 0xFF, sizeof(b32)); + CHECK(secp256k1_scalar_set_b32_seckey(&s2, b32) == 0); +} + void run_scalar_tests(void) { int i; for (i = 0; i < 128 * count; i++) { scalar_test(); } + for (i = 0; i < count; i++) { + run_scalar_set_b32_seckey_tests(); + } { /* (-1)+1 should be zero. */ @@ -985,16 +1112,43 @@ void run_scalar_tests(void) { #ifndef USE_NUM_NONE { - /* A scalar with value of the curve order should be 0. */ + /* Test secp256k1_scalar_set_b32 boundary conditions */ secp256k1_num order; - secp256k1_scalar zero; + secp256k1_scalar scalar; unsigned char bin[32]; + unsigned char bin_tmp[32]; int overflow = 0; + /* 2^256-1 - order */ + static const secp256k1_scalar all_ones_minus_order = SECP256K1_SCALAR_CONST( + 0x00000000UL, 0x00000000UL, 0x00000000UL, 0x00000001UL, + 0x45512319UL, 0x50B75FC4UL, 0x402DA173UL, 0x2FC9BEBEUL + ); + + /* A scalar set to 0s should be 0. */ + memset(bin, 0, 32); + secp256k1_scalar_set_b32(&scalar, bin, &overflow); + CHECK(overflow == 0); + CHECK(secp256k1_scalar_is_zero(&scalar)); + + /* A scalar with value of the curve order should be 0. */ secp256k1_scalar_order_get_num(&order); secp256k1_num_get_bin(bin, 32, &order); - secp256k1_scalar_set_b32(&zero, bin, &overflow); + secp256k1_scalar_set_b32(&scalar, bin, &overflow); + CHECK(overflow == 1); + CHECK(secp256k1_scalar_is_zero(&scalar)); + + /* A scalar with value of the curve order minus one should not overflow. */ + bin[31] -= 1; + secp256k1_scalar_set_b32(&scalar, bin, &overflow); + CHECK(overflow == 0); + secp256k1_scalar_get_b32(bin_tmp, &scalar); + CHECK(memcmp(bin, bin_tmp, 32) == 0); + + /* A scalar set to all 1s should overflow. */ + memset(bin, 0xFF, 32); + secp256k1_scalar_set_b32(&scalar, bin, &overflow); CHECK(overflow == 1); - CHECK(secp256k1_scalar_is_zero(&zero)); + CHECK(secp256k1_scalar_eq(&scalar, &all_ones_minus_order)); } #endif @@ -1709,24 +1863,32 @@ void run_field_misc(void) { /* Test fe conditional move; z is not normalized here. */ q = x; secp256k1_fe_cmov(&x, &z, 0); - VERIFY_CHECK(!x.normalized && x.magnitude == z.magnitude); +#ifdef VERIFY + CHECK(x.normalized && x.magnitude == 1); +#endif secp256k1_fe_cmov(&x, &x, 1); CHECK(fe_memcmp(&x, &z) != 0); CHECK(fe_memcmp(&x, &q) == 0); secp256k1_fe_cmov(&q, &z, 1); - VERIFY_CHECK(!q.normalized && q.magnitude == z.magnitude); +#ifdef VERIFY + CHECK(!q.normalized && q.magnitude == z.magnitude); +#endif CHECK(fe_memcmp(&q, &z) == 0); secp256k1_fe_normalize_var(&x); secp256k1_fe_normalize_var(&z); CHECK(!secp256k1_fe_equal_var(&x, &z)); secp256k1_fe_normalize_var(&q); secp256k1_fe_cmov(&q, &z, (i&1)); - VERIFY_CHECK(q.normalized && q.magnitude == 1); +#ifdef VERIFY + CHECK(q.normalized && q.magnitude == 1); +#endif for (j = 0; j < 6; j++) { secp256k1_fe_negate(&z, &z, j+1); secp256k1_fe_normalize_var(&q); secp256k1_fe_cmov(&q, &z, (j&1)); - VERIFY_CHECK(!q.normalized && q.magnitude == (j+2)); +#ifdef VERIFY + CHECK((q.normalized != (j&1)) && q.magnitude == ((j&1) ? z.magnitude : 1)); +#endif } secp256k1_fe_normalize_var(&z); /* Test storage conversion and conditional moves. */ @@ -2120,7 +2282,7 @@ void test_ge(void) { /* Test batch gej -> ge conversion with many infinities. */ for (i = 0; i < 4 * runs + 1; i++) { random_group_element_test(&ge[i]); - /* randomly set half the points to infinitiy */ + /* randomly set half the points to infinity */ if(secp256k1_fe_is_odd(&ge[i].x)) { secp256k1_ge_set_infinity(&ge[i]); } @@ -2572,14 +2734,13 @@ void test_ecmult_multi(secp256k1_scratch *scratch, secp256k1_ecmult_multi_func e secp256k1_gej r; secp256k1_gej r2; ecmult_multi_data data; - secp256k1_scratch *scratch_empty; data.sc = sc; data.pt = pt; secp256k1_scalar_set_int(&szero, 0); /* No points to multiply */ - CHECK(ecmult_multi(&ctx->ecmult_ctx, scratch, &r, NULL, ecmult_multi_callback, &data, 0)); + CHECK(ecmult_multi(&ctx->error_callback, &ctx->ecmult_ctx, scratch, &r, NULL, ecmult_multi_callback, &data, 0)); /* Check 1- and 2-point multiplies against ecmult */ for (ncount = 0; ncount < count; ncount++) { @@ -2595,36 +2756,31 @@ void test_ecmult_multi(secp256k1_scratch *scratch, secp256k1_ecmult_multi_func e /* only G scalar */ secp256k1_ecmult(&ctx->ecmult_ctx, &r2, &ptgj, &szero, &sc[0]); - CHECK(ecmult_multi(&ctx->ecmult_ctx, scratch, &r, &sc[0], ecmult_multi_callback, &data, 0)); + CHECK(ecmult_multi(&ctx->error_callback, &ctx->ecmult_ctx, scratch, &r, &sc[0], ecmult_multi_callback, &data, 0)); secp256k1_gej_neg(&r2, &r2); secp256k1_gej_add_var(&r, &r, &r2, NULL); CHECK(secp256k1_gej_is_infinity(&r)); /* 1-point */ secp256k1_ecmult(&ctx->ecmult_ctx, &r2, &ptgj, &sc[0], &szero); - CHECK(ecmult_multi(&ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_callback, &data, 1)); + CHECK(ecmult_multi(&ctx->error_callback, &ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_callback, &data, 1)); secp256k1_gej_neg(&r2, &r2); secp256k1_gej_add_var(&r, &r, &r2, NULL); CHECK(secp256k1_gej_is_infinity(&r)); - /* Try to multiply 1 point, but scratch space is empty */ - scratch_empty = secp256k1_scratch_create(&ctx->error_callback, 0); - CHECK(!ecmult_multi(&ctx->ecmult_ctx, scratch_empty, &r, &szero, ecmult_multi_callback, &data, 1)); - secp256k1_scratch_destroy(scratch_empty); - /* Try to multiply 1 point, but callback returns false */ - CHECK(!ecmult_multi(&ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_false_callback, &data, 1)); + CHECK(!ecmult_multi(&ctx->error_callback, &ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_false_callback, &data, 1)); /* 2-point */ secp256k1_ecmult(&ctx->ecmult_ctx, &r2, &ptgj, &sc[0], &sc[1]); - CHECK(ecmult_multi(&ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_callback, &data, 2)); + CHECK(ecmult_multi(&ctx->error_callback, &ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_callback, &data, 2)); secp256k1_gej_neg(&r2, &r2); secp256k1_gej_add_var(&r, &r, &r2, NULL); CHECK(secp256k1_gej_is_infinity(&r)); /* 2-point with G scalar */ secp256k1_ecmult(&ctx->ecmult_ctx, &r2, &ptgj, &sc[0], &sc[1]); - CHECK(ecmult_multi(&ctx->ecmult_ctx, scratch, &r, &sc[1], ecmult_multi_callback, &data, 1)); + CHECK(ecmult_multi(&ctx->error_callback, &ctx->ecmult_ctx, scratch, &r, &sc[1], ecmult_multi_callback, &data, 1)); secp256k1_gej_neg(&r2, &r2); secp256k1_gej_add_var(&r, &r, &r2, NULL); CHECK(secp256k1_gej_is_infinity(&r)); @@ -2641,7 +2797,7 @@ void test_ecmult_multi(secp256k1_scratch *scratch, secp256k1_ecmult_multi_func e random_scalar_order(&sc[i]); secp256k1_ge_set_infinity(&pt[i]); } - CHECK(ecmult_multi(&ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_callback, &data, sizes[j])); + CHECK(ecmult_multi(&ctx->error_callback, &ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_callback, &data, sizes[j])); CHECK(secp256k1_gej_is_infinity(&r)); } @@ -2651,7 +2807,7 @@ void test_ecmult_multi(secp256k1_scratch *scratch, secp256k1_ecmult_multi_func e pt[i] = ptg; secp256k1_scalar_set_int(&sc[i], 0); } - CHECK(ecmult_multi(&ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_callback, &data, sizes[j])); + CHECK(ecmult_multi(&ctx->error_callback, &ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_callback, &data, sizes[j])); CHECK(secp256k1_gej_is_infinity(&r)); } @@ -2664,7 +2820,7 @@ void test_ecmult_multi(secp256k1_scratch *scratch, secp256k1_ecmult_multi_func e pt[2 * i + 1] = ptg; } - CHECK(ecmult_multi(&ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_callback, &data, sizes[j])); + CHECK(ecmult_multi(&ctx->error_callback, &ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_callback, &data, sizes[j])); CHECK(secp256k1_gej_is_infinity(&r)); random_scalar_order(&sc[0]); @@ -2677,7 +2833,7 @@ void test_ecmult_multi(secp256k1_scratch *scratch, secp256k1_ecmult_multi_func e secp256k1_ge_neg(&pt[2*i+1], &pt[2*i]); } - CHECK(ecmult_multi(&ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_callback, &data, sizes[j])); + CHECK(ecmult_multi(&ctx->error_callback, &ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_callback, &data, sizes[j])); CHECK(secp256k1_gej_is_infinity(&r)); } @@ -2692,7 +2848,7 @@ void test_ecmult_multi(secp256k1_scratch *scratch, secp256k1_ecmult_multi_func e secp256k1_scalar_negate(&sc[i], &sc[i]); } - CHECK(ecmult_multi(&ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_callback, &data, 32)); + CHECK(ecmult_multi(&ctx->error_callback, &ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_callback, &data, 32)); CHECK(secp256k1_gej_is_infinity(&r)); } @@ -2711,7 +2867,7 @@ void test_ecmult_multi(secp256k1_scratch *scratch, secp256k1_ecmult_multi_func e } secp256k1_ecmult(&ctx->ecmult_ctx, &r2, &r, &sc[0], &szero); - CHECK(ecmult_multi(&ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_callback, &data, 20)); + CHECK(ecmult_multi(&ctx->error_callback, &ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_callback, &data, 20)); secp256k1_gej_neg(&r2, &r2); secp256k1_gej_add_var(&r, &r, &r2, NULL); CHECK(secp256k1_gej_is_infinity(&r)); @@ -2734,7 +2890,7 @@ void test_ecmult_multi(secp256k1_scratch *scratch, secp256k1_ecmult_multi_func e secp256k1_gej_set_ge(&p0j, &pt[0]); secp256k1_ecmult(&ctx->ecmult_ctx, &r2, &p0j, &rs, &szero); - CHECK(ecmult_multi(&ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_callback, &data, 20)); + CHECK(ecmult_multi(&ctx->error_callback, &ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_callback, &data, 20)); secp256k1_gej_neg(&r2, &r2); secp256k1_gej_add_var(&r, &r, &r2, NULL); CHECK(secp256k1_gej_is_infinity(&r)); @@ -2747,13 +2903,13 @@ void test_ecmult_multi(secp256k1_scratch *scratch, secp256k1_ecmult_multi_func e } secp256k1_scalar_clear(&sc[0]); - CHECK(ecmult_multi(&ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_callback, &data, 20)); + CHECK(ecmult_multi(&ctx->error_callback, &ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_callback, &data, 20)); secp256k1_scalar_clear(&sc[1]); secp256k1_scalar_clear(&sc[2]); secp256k1_scalar_clear(&sc[3]); secp256k1_scalar_clear(&sc[4]); - CHECK(ecmult_multi(&ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_callback, &data, 6)); - CHECK(ecmult_multi(&ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_callback, &data, 5)); + CHECK(ecmult_multi(&ctx->error_callback, &ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_callback, &data, 6)); + CHECK(ecmult_multi(&ctx->error_callback, &ctx->ecmult_ctx, scratch, &r, &szero, ecmult_multi_callback, &data, 5)); CHECK(secp256k1_gej_is_infinity(&r)); /* Run through s0*(t0*P) + s1*(t1*P) exhaustively for many small values of s0, s1, t0, t1 */ @@ -2798,7 +2954,7 @@ void test_ecmult_multi(secp256k1_scratch *scratch, secp256k1_ecmult_multi_func e secp256k1_scalar_add(&tmp1, &tmp1, &tmp2); secp256k1_ecmult(&ctx->ecmult_ctx, &expected, &ptgj, &tmp1, &szero); - CHECK(ecmult_multi(&ctx->ecmult_ctx, scratch, &actual, &szero, ecmult_multi_callback, &data, 2)); + CHECK(ecmult_multi(&ctx->error_callback, &ctx->ecmult_ctx, scratch, &actual, &szero, ecmult_multi_callback, &data, 2)); secp256k1_gej_neg(&expected, &expected); secp256k1_gej_add_var(&actual, &actual, &expected, NULL); CHECK(secp256k1_gej_is_infinity(&actual)); @@ -2809,6 +2965,24 @@ void test_ecmult_multi(secp256k1_scratch *scratch, secp256k1_ecmult_multi_func e } } +void test_ecmult_multi_batch_single(secp256k1_ecmult_multi_func ecmult_multi) { + secp256k1_scalar szero; + secp256k1_scalar sc[32]; + secp256k1_ge pt[32]; + secp256k1_gej r; + ecmult_multi_data data; + secp256k1_scratch *scratch_empty; + + data.sc = sc; + data.pt = pt; + secp256k1_scalar_set_int(&szero, 0); + + /* Try to multiply 1 point, but scratch space is empty.*/ + scratch_empty = secp256k1_scratch_create(&ctx->error_callback, 0); + CHECK(!ecmult_multi(&ctx->error_callback, &ctx->ecmult_ctx, scratch_empty, &r, &szero, ecmult_multi_callback, &data, 1)); + secp256k1_scratch_destroy(&ctx->error_callback, scratch_empty); +} + void test_secp256k1_pippenger_bucket_window_inv(void) { int i; @@ -2839,17 +3013,27 @@ void test_ecmult_multi_pippenger_max_points(void) { int bucket_window = 0; for(; scratch_size < max_size; scratch_size+=256) { + size_t i; + size_t total_alloc; + size_t checkpoint; scratch = secp256k1_scratch_create(&ctx->error_callback, scratch_size); CHECK(scratch != NULL); - n_points_supported = secp256k1_pippenger_max_points(scratch); + checkpoint = secp256k1_scratch_checkpoint(&ctx->error_callback, scratch); + n_points_supported = secp256k1_pippenger_max_points(&ctx->error_callback, scratch); if (n_points_supported == 0) { - secp256k1_scratch_destroy(scratch); + secp256k1_scratch_destroy(&ctx->error_callback, scratch); continue; } bucket_window = secp256k1_pippenger_bucket_window(n_points_supported); - CHECK(secp256k1_scratch_allocate_frame(scratch, secp256k1_pippenger_scratch_size(n_points_supported, bucket_window), PIPPENGER_SCRATCH_OBJECTS)); - secp256k1_scratch_deallocate_frame(scratch); - secp256k1_scratch_destroy(scratch); + /* allocate `total_alloc` bytes over `PIPPENGER_SCRATCH_OBJECTS` many allocations */ + total_alloc = secp256k1_pippenger_scratch_size(n_points_supported, bucket_window); + for (i = 0; i < PIPPENGER_SCRATCH_OBJECTS - 1; i++) { + CHECK(secp256k1_scratch_alloc(&ctx->error_callback, scratch, 1)); + total_alloc--; + } + CHECK(secp256k1_scratch_alloc(&ctx->error_callback, scratch, total_alloc)); + secp256k1_scratch_apply_checkpoint(&ctx->error_callback, scratch, checkpoint); + secp256k1_scratch_destroy(&ctx->error_callback, scratch); } CHECK(bucket_window == PIPPENGER_MAX_BUCKET_WINDOW); } @@ -2932,19 +3116,25 @@ void test_ecmult_multi_batching(void) { } data.sc = sc; data.pt = pt; + secp256k1_gej_neg(&r2, &r2); - /* Test with empty scratch space */ + /* Test with empty scratch space. It should compute the correct result using + * ecmult_mult_simple algorithm which doesn't require a scratch space. */ scratch = secp256k1_scratch_create(&ctx->error_callback, 0); - CHECK(!secp256k1_ecmult_multi_var(&ctx->ecmult_ctx, scratch, &r, &scG, ecmult_multi_callback, &data, 1)); - secp256k1_scratch_destroy(scratch); + CHECK(secp256k1_ecmult_multi_var(&ctx->error_callback, &ctx->ecmult_ctx, scratch, &r, &scG, ecmult_multi_callback, &data, n_points)); + secp256k1_gej_add_var(&r, &r, &r2, NULL); + CHECK(secp256k1_gej_is_infinity(&r)); + secp256k1_scratch_destroy(&ctx->error_callback, scratch); /* Test with space for 1 point in pippenger. That's not enough because - * ecmult_multi selects strauss which requires more memory. */ + * ecmult_multi selects strauss which requires more memory. It should + * therefore select the simple algorithm. */ scratch = secp256k1_scratch_create(&ctx->error_callback, secp256k1_pippenger_scratch_size(1, 1) + PIPPENGER_SCRATCH_OBJECTS*ALIGNMENT); - CHECK(!secp256k1_ecmult_multi_var(&ctx->ecmult_ctx, scratch, &r, &scG, ecmult_multi_callback, &data, 1)); - secp256k1_scratch_destroy(scratch); + CHECK(secp256k1_ecmult_multi_var(&ctx->error_callback, &ctx->ecmult_ctx, scratch, &r, &scG, ecmult_multi_callback, &data, n_points)); + secp256k1_gej_add_var(&r, &r, &r2, NULL); + CHECK(secp256k1_gej_is_infinity(&r)); + secp256k1_scratch_destroy(&ctx->error_callback, scratch); - secp256k1_gej_neg(&r2, &r2); for(i = 1; i <= n_points; i++) { if (i > ECMULT_PIPPENGER_THRESHOLD) { int bucket_window = secp256k1_pippenger_bucket_window(i); @@ -2954,10 +3144,10 @@ void test_ecmult_multi_batching(void) { size_t scratch_size = secp256k1_strauss_scratch_size(i); scratch = secp256k1_scratch_create(&ctx->error_callback, scratch_size + STRAUSS_SCRATCH_OBJECTS*ALIGNMENT); } - CHECK(secp256k1_ecmult_multi_var(&ctx->ecmult_ctx, scratch, &r, &scG, ecmult_multi_callback, &data, n_points)); + CHECK(secp256k1_ecmult_multi_var(&ctx->error_callback, &ctx->ecmult_ctx, scratch, &r, &scG, ecmult_multi_callback, &data, n_points)); secp256k1_gej_add_var(&r, &r, &r2, NULL); CHECK(secp256k1_gej_is_infinity(&r)); - secp256k1_scratch_destroy(scratch); + secp256k1_scratch_destroy(&ctx->error_callback, scratch); } free(sc); free(pt); @@ -2972,13 +3162,15 @@ void run_ecmult_multi_tests(void) { test_ecmult_multi(scratch, secp256k1_ecmult_multi_var); test_ecmult_multi(NULL, secp256k1_ecmult_multi_var); test_ecmult_multi(scratch, secp256k1_ecmult_pippenger_batch_single); + test_ecmult_multi_batch_single(secp256k1_ecmult_pippenger_batch_single); test_ecmult_multi(scratch, secp256k1_ecmult_strauss_batch_single); - secp256k1_scratch_destroy(scratch); + test_ecmult_multi_batch_single(secp256k1_ecmult_strauss_batch_single); + secp256k1_scratch_destroy(&ctx->error_callback, scratch); /* Run test_ecmult_multi with space for exactly one point */ scratch = secp256k1_scratch_create(&ctx->error_callback, secp256k1_strauss_scratch_size(1) + STRAUSS_SCRATCH_OBJECTS*ALIGNMENT); test_ecmult_multi(scratch, secp256k1_ecmult_multi_var); - secp256k1_scratch_destroy(scratch); + secp256k1_scratch_destroy(&ctx->error_callback, scratch); test_ecmult_multi_batch_size_helper(); test_ecmult_multi_batching(); @@ -3050,7 +3242,7 @@ void test_constant_wnaf(const secp256k1_scalar *number, int w) { } bits = 128; #endif - skew = secp256k1_wnaf_const(wnaf, num, w, bits); + skew = secp256k1_wnaf_const(wnaf, &num, w, bits); for (i = WNAF_SIZE_BITS(bits, w); i >= 0; --i) { secp256k1_scalar t; @@ -3786,37 +3978,57 @@ void run_eckey_edge_case_test(void) { pubkey_negone = pubkey; /* Tweak of zero leaves the value unchanged. */ memset(ctmp2, 0, 32); - CHECK(secp256k1_ec_privkey_tweak_add(ctx, ctmp, ctmp2) == 1); + CHECK(secp256k1_ec_seckey_tweak_add(ctx, ctmp, ctmp2) == 1); CHECK(memcmp(orderc, ctmp, 31) == 0 && ctmp[31] == 0x40); memcpy(&pubkey2, &pubkey, sizeof(pubkey)); CHECK(secp256k1_ec_pubkey_tweak_add(ctx, &pubkey, ctmp2) == 1); CHECK(memcmp(&pubkey, &pubkey2, sizeof(pubkey)) == 0); /* Multiply tweak of zero zeroizes the output. */ - CHECK(secp256k1_ec_privkey_tweak_mul(ctx, ctmp, ctmp2) == 0); + CHECK(secp256k1_ec_seckey_tweak_mul(ctx, ctmp, ctmp2) == 0); CHECK(memcmp(zeros, ctmp, 32) == 0); CHECK(secp256k1_ec_pubkey_tweak_mul(ctx, &pubkey, ctmp2) == 0); CHECK(memcmp(&pubkey, zeros, sizeof(pubkey)) == 0); memcpy(&pubkey, &pubkey2, sizeof(pubkey)); - /* Overflowing key tweak zeroizes. */ + /* If seckey_tweak_add or seckey_tweak_mul are called with an overflowing + seckey, the seckey is zeroized. */ + memcpy(ctmp, orderc, 32); + memset(ctmp2, 0, 32); + ctmp2[31] = 0x01; + CHECK(secp256k1_ec_seckey_verify(ctx, ctmp2) == 1); + CHECK(secp256k1_ec_seckey_verify(ctx, ctmp) == 0); + CHECK(secp256k1_ec_seckey_tweak_add(ctx, ctmp, ctmp2) == 0); + CHECK(memcmp(zeros, ctmp, 32) == 0); + memcpy(ctmp, orderc, 32); + CHECK(secp256k1_ec_seckey_tweak_mul(ctx, ctmp, ctmp2) == 0); + CHECK(memcmp(zeros, ctmp, 32) == 0); + /* If seckey_tweak_add or seckey_tweak_mul are called with an overflowing + tweak, the seckey is zeroized. */ memcpy(ctmp, orderc, 32); ctmp[31] = 0x40; - CHECK(secp256k1_ec_privkey_tweak_add(ctx, ctmp, orderc) == 0); + CHECK(secp256k1_ec_seckey_tweak_add(ctx, ctmp, orderc) == 0); CHECK(memcmp(zeros, ctmp, 32) == 0); memcpy(ctmp, orderc, 32); ctmp[31] = 0x40; - CHECK(secp256k1_ec_privkey_tweak_mul(ctx, ctmp, orderc) == 0); + CHECK(secp256k1_ec_seckey_tweak_mul(ctx, ctmp, orderc) == 0); CHECK(memcmp(zeros, ctmp, 32) == 0); memcpy(ctmp, orderc, 32); ctmp[31] = 0x40; + /* If pubkey_tweak_add or pubkey_tweak_mul are called with an overflowing + tweak, the pubkey is zeroized. */ CHECK(secp256k1_ec_pubkey_tweak_add(ctx, &pubkey, orderc) == 0); CHECK(memcmp(&pubkey, zeros, sizeof(pubkey)) == 0); memcpy(&pubkey, &pubkey2, sizeof(pubkey)); CHECK(secp256k1_ec_pubkey_tweak_mul(ctx, &pubkey, orderc) == 0); CHECK(memcmp(&pubkey, zeros, sizeof(pubkey)) == 0); memcpy(&pubkey, &pubkey2, sizeof(pubkey)); - /* Private key tweaks results in a key of zero. */ + /* If the resulting key in secp256k1_ec_seckey_tweak_add and + * secp256k1_ec_pubkey_tweak_add is 0 the functions fail and in the latter + * case the pubkey is zeroized. */ + memcpy(ctmp, orderc, 32); + ctmp[31] = 0x40; + memset(ctmp2, 0, 32); ctmp2[31] = 1; - CHECK(secp256k1_ec_privkey_tweak_add(ctx, ctmp2, ctmp) == 0); + CHECK(secp256k1_ec_seckey_tweak_add(ctx, ctmp2, ctmp) == 0); CHECK(memcmp(zeros, ctmp2, 32) == 0); ctmp2[31] = 1; CHECK(secp256k1_ec_pubkey_tweak_add(ctx, &pubkey, ctmp2) == 0); @@ -3824,7 +4036,7 @@ void run_eckey_edge_case_test(void) { memcpy(&pubkey, &pubkey2, sizeof(pubkey)); /* Tweak computation wraps and results in a key of 1. */ ctmp2[31] = 2; - CHECK(secp256k1_ec_privkey_tweak_add(ctx, ctmp2, ctmp) == 1); + CHECK(secp256k1_ec_seckey_tweak_add(ctx, ctmp2, ctmp) == 1); CHECK(memcmp(ctmp2, zeros, 31) == 0 && ctmp2[31] == 1); ctmp2[31] = 2; CHECK(secp256k1_ec_pubkey_tweak_add(ctx, &pubkey, ctmp2) == 1); @@ -3872,16 +4084,16 @@ void run_eckey_edge_case_test(void) { CHECK(ecount == 2); ecount = 0; memset(ctmp2, 0, 32); - CHECK(secp256k1_ec_privkey_tweak_add(ctx, NULL, ctmp2) == 0); + CHECK(secp256k1_ec_seckey_tweak_add(ctx, NULL, ctmp2) == 0); CHECK(ecount == 1); - CHECK(secp256k1_ec_privkey_tweak_add(ctx, ctmp, NULL) == 0); + CHECK(secp256k1_ec_seckey_tweak_add(ctx, ctmp, NULL) == 0); CHECK(ecount == 2); ecount = 0; memset(ctmp2, 0, 32); ctmp2[31] = 1; - CHECK(secp256k1_ec_privkey_tweak_mul(ctx, NULL, ctmp2) == 0); + CHECK(secp256k1_ec_seckey_tweak_mul(ctx, NULL, ctmp2) == 0); CHECK(ecount == 1); - CHECK(secp256k1_ec_privkey_tweak_mul(ctx, ctmp, NULL) == 0); + CHECK(secp256k1_ec_seckey_tweak_mul(ctx, ctmp, NULL) == 0); CHECK(ecount == 2); ecount = 0; CHECK(secp256k1_ec_pubkey_create(ctx, NULL, ctmp) == 0); @@ -3954,6 +4166,41 @@ void run_eckey_edge_case_test(void) { secp256k1_context_set_illegal_callback(ctx, NULL, NULL); } +void run_eckey_negate_test(void) { + unsigned char seckey[32]; + unsigned char seckey_tmp[32]; + + random_scalar_order_b32(seckey); + memcpy(seckey_tmp, seckey, 32); + + /* Verify negation changes the key and changes it back */ + CHECK(secp256k1_ec_seckey_negate(ctx, seckey) == 1); + CHECK(memcmp(seckey, seckey_tmp, 32) != 0); + CHECK(secp256k1_ec_seckey_negate(ctx, seckey) == 1); + CHECK(memcmp(seckey, seckey_tmp, 32) == 0); + + /* Check that privkey alias gives same result */ + CHECK(secp256k1_ec_seckey_negate(ctx, seckey) == 1); + CHECK(secp256k1_ec_privkey_negate(ctx, seckey_tmp) == 1); + CHECK(memcmp(seckey, seckey_tmp, 32) == 0); + + /* Negating all 0s fails */ + memset(seckey, 0, 32); + memset(seckey_tmp, 0, 32); + CHECK(secp256k1_ec_seckey_negate(ctx, seckey) == 0); + /* Check that seckey is not modified */ + CHECK(memcmp(seckey, seckey_tmp, 32) == 0); + + /* Negating an overflowing seckey fails and the seckey is zeroed. In this + * test, the seckey has 16 random bytes to ensure that ec_seckey_negate + * doesn't just set seckey to a constant value in case of failure. */ + random_scalar_order_b32(seckey); + memset(seckey, 0xFF, 16); + memset(seckey_tmp, 0, 32); + CHECK(secp256k1_ec_seckey_negate(ctx, seckey) == 0); + CHECK(memcmp(seckey, seckey_tmp, 32) == 0); +} + void random_sign(secp256k1_scalar *sigr, secp256k1_scalar *sigs, const secp256k1_scalar *key, const secp256k1_scalar *msg, int *recid) { secp256k1_scalar nonce; do { @@ -4093,15 +4340,22 @@ void test_ecdsa_end_to_end(void) { if (secp256k1_rand_int(3) == 0) { int ret1; int ret2; + int ret3; unsigned char rnd[32]; + unsigned char privkey_tmp[32]; secp256k1_pubkey pubkey2; secp256k1_rand256_test(rnd); - ret1 = secp256k1_ec_privkey_tweak_add(ctx, privkey, rnd); + memcpy(privkey_tmp, privkey, 32); + ret1 = secp256k1_ec_seckey_tweak_add(ctx, privkey, rnd); ret2 = secp256k1_ec_pubkey_tweak_add(ctx, &pubkey, rnd); + /* Check that privkey alias gives same result */ + ret3 = secp256k1_ec_privkey_tweak_add(ctx, privkey_tmp, rnd); CHECK(ret1 == ret2); + CHECK(ret2 == ret3); if (ret1 == 0) { return; } + CHECK(memcmp(privkey, privkey_tmp, 32) == 0); CHECK(secp256k1_ec_pubkey_create(ctx, &pubkey2, privkey) == 1); CHECK(memcmp(&pubkey, &pubkey2, sizeof(pubkey)) == 0); } @@ -4110,15 +4364,22 @@ void test_ecdsa_end_to_end(void) { if (secp256k1_rand_int(3) == 0) { int ret1; int ret2; + int ret3; unsigned char rnd[32]; + unsigned char privkey_tmp[32]; secp256k1_pubkey pubkey2; secp256k1_rand256_test(rnd); - ret1 = secp256k1_ec_privkey_tweak_mul(ctx, privkey, rnd); + memcpy(privkey_tmp, privkey, 32); + ret1 = secp256k1_ec_seckey_tweak_mul(ctx, privkey, rnd); ret2 = secp256k1_ec_pubkey_tweak_mul(ctx, &pubkey, rnd); + /* Check that privkey alias gives same result */ + ret3 = secp256k1_ec_privkey_tweak_mul(ctx, privkey_tmp, rnd); CHECK(ret1 == ret2); + CHECK(ret2 == ret3); if (ret1 == 0) { return; } + CHECK(memcmp(privkey, privkey_tmp, 32) == 0); CHECK(secp256k1_ec_pubkey_create(ctx, &pubkey2, privkey) == 1); CHECK(memcmp(&pubkey, &pubkey2, sizeof(pubkey)) == 0); } @@ -4315,7 +4576,7 @@ int test_ecdsa_der_parse(const unsigned char *sig, size_t siglen, int certainly_ if (valid_der) { ret |= (!roundtrips_der_lax) << 12; ret |= (len_der != len_der_lax) << 13; - ret |= (memcmp(roundtrip_der_lax, roundtrip_der, len_der) != 0) << 14; + ret |= ((len_der != len_der_lax) || (memcmp(roundtrip_der_lax, roundtrip_der, len_der) != 0)) << 14; } ret |= (roundtrips_der != roundtrips_der_lax) << 15; if (parsed_der) { @@ -4356,7 +4617,7 @@ int test_ecdsa_der_parse(const unsigned char *sig, size_t siglen, int certainly_ ret |= (roundtrips_der != roundtrips_openssl) << 7; if (roundtrips_openssl) { ret |= (len_der != (size_t)len_openssl) << 8; - ret |= (memcmp(roundtrip_der, roundtrip_openssl, len_der) != 0) << 9; + ret |= ((len_der != (size_t)len_openssl) || (memcmp(roundtrip_der, roundtrip_openssl, len_der) != 0)) << 9; } #endif return ret; @@ -5016,9 +5277,188 @@ void run_ecdsa_openssl(void) { # include "modules/recovery/tests_impl.h" #endif +void run_memczero_test(void) { + unsigned char buf1[6] = {1, 2, 3, 4, 5, 6}; + unsigned char buf2[sizeof(buf1)]; + + /* memczero(..., ..., 0) is a noop. */ + memcpy(buf2, buf1, sizeof(buf1)); + memczero(buf1, sizeof(buf1), 0); + CHECK(memcmp(buf1, buf2, sizeof(buf1)) == 0); + + /* memczero(..., ..., 1) zeros the buffer. */ + memset(buf2, 0, sizeof(buf2)); + memczero(buf1, sizeof(buf1) , 1); + CHECK(memcmp(buf1, buf2, sizeof(buf1)) == 0); +} + +void int_cmov_test(void) { + int r = INT_MAX; + int a = 0; + + secp256k1_int_cmov(&r, &a, 0); + CHECK(r == INT_MAX); + + r = 0; a = INT_MAX; + secp256k1_int_cmov(&r, &a, 1); + CHECK(r == INT_MAX); + + a = 0; + secp256k1_int_cmov(&r, &a, 1); + CHECK(r == 0); + + a = 1; + secp256k1_int_cmov(&r, &a, 1); + CHECK(r == 1); + + r = 1; a = 0; + secp256k1_int_cmov(&r, &a, 0); + CHECK(r == 1); + +} + +void fe_cmov_test(void) { + static const secp256k1_fe zero = SECP256K1_FE_CONST(0, 0, 0, 0, 0, 0, 0, 0); + static const secp256k1_fe one = SECP256K1_FE_CONST(0, 0, 0, 0, 0, 0, 0, 1); + static const secp256k1_fe max = SECP256K1_FE_CONST( + 0xFFFFFFFFUL, 0xFFFFFFFFUL, 0xFFFFFFFFUL, 0xFFFFFFFFUL, + 0xFFFFFFFFUL, 0xFFFFFFFFUL, 0xFFFFFFFFUL, 0xFFFFFFFFUL + ); + secp256k1_fe r = max; + secp256k1_fe a = zero; + + secp256k1_fe_cmov(&r, &a, 0); + CHECK(memcmp(&r, &max, sizeof(r)) == 0); + + r = zero; a = max; + secp256k1_fe_cmov(&r, &a, 1); + CHECK(memcmp(&r, &max, sizeof(r)) == 0); + + a = zero; + secp256k1_fe_cmov(&r, &a, 1); + CHECK(memcmp(&r, &zero, sizeof(r)) == 0); + + a = one; + secp256k1_fe_cmov(&r, &a, 1); + CHECK(memcmp(&r, &one, sizeof(r)) == 0); + + r = one; a = zero; + secp256k1_fe_cmov(&r, &a, 0); + CHECK(memcmp(&r, &one, sizeof(r)) == 0); +} + +void fe_storage_cmov_test(void) { + static const secp256k1_fe_storage zero = SECP256K1_FE_STORAGE_CONST(0, 0, 0, 0, 0, 0, 0, 0); + static const secp256k1_fe_storage one = SECP256K1_FE_STORAGE_CONST(0, 0, 0, 0, 0, 0, 0, 1); + static const secp256k1_fe_storage max = SECP256K1_FE_STORAGE_CONST( + 0xFFFFFFFFUL, 0xFFFFFFFFUL, 0xFFFFFFFFUL, 0xFFFFFFFFUL, + 0xFFFFFFFFUL, 0xFFFFFFFFUL, 0xFFFFFFFFUL, 0xFFFFFFFFUL + ); + secp256k1_fe_storage r = max; + secp256k1_fe_storage a = zero; + + secp256k1_fe_storage_cmov(&r, &a, 0); + CHECK(memcmp(&r, &max, sizeof(r)) == 0); + + r = zero; a = max; + secp256k1_fe_storage_cmov(&r, &a, 1); + CHECK(memcmp(&r, &max, sizeof(r)) == 0); + + a = zero; + secp256k1_fe_storage_cmov(&r, &a, 1); + CHECK(memcmp(&r, &zero, sizeof(r)) == 0); + + a = one; + secp256k1_fe_storage_cmov(&r, &a, 1); + CHECK(memcmp(&r, &one, sizeof(r)) == 0); + + r = one; a = zero; + secp256k1_fe_storage_cmov(&r, &a, 0); + CHECK(memcmp(&r, &one, sizeof(r)) == 0); +} + +void scalar_cmov_test(void) { + static const secp256k1_scalar zero = SECP256K1_SCALAR_CONST(0, 0, 0, 0, 0, 0, 0, 0); + static const secp256k1_scalar one = SECP256K1_SCALAR_CONST(0, 0, 0, 0, 0, 0, 0, 1); + static const secp256k1_scalar max = SECP256K1_SCALAR_CONST( + 0xFFFFFFFFUL, 0xFFFFFFFFUL, 0xFFFFFFFFUL, 0xFFFFFFFFUL, + 0xFFFFFFFFUL, 0xFFFFFFFFUL, 0xFFFFFFFFUL, 0xFFFFFFFFUL + ); + secp256k1_scalar r = max; + secp256k1_scalar a = zero; + + secp256k1_scalar_cmov(&r, &a, 0); + CHECK(memcmp(&r, &max, sizeof(r)) == 0); + + r = zero; a = max; + secp256k1_scalar_cmov(&r, &a, 1); + CHECK(memcmp(&r, &max, sizeof(r)) == 0); + + a = zero; + secp256k1_scalar_cmov(&r, &a, 1); + CHECK(memcmp(&r, &zero, sizeof(r)) == 0); + + a = one; + secp256k1_scalar_cmov(&r, &a, 1); + CHECK(memcmp(&r, &one, sizeof(r)) == 0); + + r = one; a = zero; + secp256k1_scalar_cmov(&r, &a, 0); + CHECK(memcmp(&r, &one, sizeof(r)) == 0); +} + +void ge_storage_cmov_test(void) { + static const secp256k1_ge_storage zero = SECP256K1_GE_STORAGE_CONST(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); + static const secp256k1_ge_storage one = SECP256K1_GE_STORAGE_CONST(0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1); + static const secp256k1_ge_storage max = SECP256K1_GE_STORAGE_CONST( + 0xFFFFFFFFUL, 0xFFFFFFFFUL, 0xFFFFFFFFUL, 0xFFFFFFFFUL, + 0xFFFFFFFFUL, 0xFFFFFFFFUL, 0xFFFFFFFFUL, 0xFFFFFFFFUL, + 0xFFFFFFFFUL, 0xFFFFFFFFUL, 0xFFFFFFFFUL, 0xFFFFFFFFUL, + 0xFFFFFFFFUL, 0xFFFFFFFFUL, 0xFFFFFFFFUL, 0xFFFFFFFFUL + ); + secp256k1_ge_storage r = max; + secp256k1_ge_storage a = zero; + + secp256k1_ge_storage_cmov(&r, &a, 0); + CHECK(memcmp(&r, &max, sizeof(r)) == 0); + + r = zero; a = max; + secp256k1_ge_storage_cmov(&r, &a, 1); + CHECK(memcmp(&r, &max, sizeof(r)) == 0); + + a = zero; + secp256k1_ge_storage_cmov(&r, &a, 1); + CHECK(memcmp(&r, &zero, sizeof(r)) == 0); + + a = one; + secp256k1_ge_storage_cmov(&r, &a, 1); + CHECK(memcmp(&r, &one, sizeof(r)) == 0); + + r = one; a = zero; + secp256k1_ge_storage_cmov(&r, &a, 0); + CHECK(memcmp(&r, &one, sizeof(r)) == 0); +} + +void run_cmov_tests(void) { + int_cmov_test(); + fe_cmov_test(); + fe_storage_cmov_test(); + scalar_cmov_test(); + ge_storage_cmov_test(); +} + int main(int argc, char **argv) { unsigned char seed16[16] = {0}; unsigned char run32[32] = {0}; + + /* Disable buffering for stdout to improve reliability of getting + * diagnostic information. Happens right at the start of main because + * setbuf must be used before any other operation on the stream. */ + setbuf(stdout, NULL); + /* Also disable buffering for stderr because it's not guaranteed that it's + * unbuffered on all systems. */ + setbuf(stderr, NULL); + /* find iteration count */ if (argc > 1) { count = strtol(argv[1], NULL, 0); @@ -5030,7 +5470,7 @@ int main(int argc, char **argv) { const char* ch = argv[2]; while (pos < 16 && ch[0] != 0 && ch[1] != 0) { unsigned short sh; - if (sscanf(ch, "%2hx", &sh)) { + if ((sscanf(ch, "%2hx", &sh)) == 1) { seed16[pos] = sh; } else { break; @@ -5062,7 +5502,8 @@ int main(int argc, char **argv) { printf("random seed = %02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x\n", seed16[0], seed16[1], seed16[2], seed16[3], seed16[4], seed16[5], seed16[6], seed16[7], seed16[8], seed16[9], seed16[10], seed16[11], seed16[12], seed16[13], seed16[14], seed16[15]); /* initialize */ - run_context_tests(); + run_context_tests(0); + run_context_tests(1); run_scratch_tests(); ctx = secp256k1_context_create(SECP256K1_CONTEXT_SIGN | SECP256K1_CONTEXT_VERIFY); if (secp256k1_rand_bits(1)) { @@ -5119,6 +5560,9 @@ int main(int argc, char **argv) { /* EC key edge cases */ run_eckey_edge_case_test(); + /* EC key arithmetic test */ + run_eckey_negate_test(); + #ifdef ENABLE_MODULE_ECDH /* ecdh tests */ run_ecdh_tests(); @@ -5139,6 +5583,11 @@ int main(int argc, char **argv) { run_recovery_tests(); #endif + /* util tests */ + run_memczero_test(); + + run_cmov_tests(); + secp256k1_rand256(run32); printf("random run = %02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x%02x\n", run32[0], run32[1], run32[2], run32[3], run32[4], run32[5], run32[6], run32[7], run32[8], run32[9], run32[10], run32[11], run32[12], run32[13], run32[14], run32[15]); diff --git a/src/secp256k1/src/tests_exhaustive.c b/src/secp256k1/src/tests_exhaustive.c index ab9779b02f..8cca1cef21 100644 --- a/src/secp256k1/src/tests_exhaustive.c +++ b/src/secp256k1/src/tests_exhaustive.c @@ -142,7 +142,7 @@ void test_exhaustive_addition(const secp256k1_ge *group, const secp256k1_gej *gr for (i = 0; i < order; i++) { secp256k1_gej tmp; if (i > 0) { - secp256k1_gej_double_nonzero(&tmp, &groupj[i], NULL); + secp256k1_gej_double_nonzero(&tmp, &groupj[i]); ge_equals_gej(&group[(2 * i) % order], &tmp); } secp256k1_gej_double_var(&tmp, &groupj[i], NULL); @@ -212,14 +212,14 @@ void test_exhaustive_ecmult_multi(const secp256k1_context *ctx, const secp256k1_ data.pt[0] = group[x]; data.pt[1] = group[y]; - secp256k1_ecmult_multi_var(&ctx->ecmult_ctx, scratch, &tmp, &g_sc, ecmult_multi_callback, &data, 2); + secp256k1_ecmult_multi_var(&ctx->error_callback, &ctx->ecmult_ctx, scratch, &tmp, &g_sc, ecmult_multi_callback, &data, 2); ge_equals_gej(&group[(i * x + j * y + k) % order], &tmp); } } } } } - secp256k1_scratch_destroy(scratch); + secp256k1_scratch_destroy(&ctx->error_callback, scratch); } void r_from_k(secp256k1_scalar *r, const secp256k1_ge *group, int k) { diff --git a/src/secp256k1/src/util.h b/src/secp256k1/src/util.h index e1f5b76452..8289e23e0c 100644 --- a/src/secp256k1/src/util.h +++ b/src/secp256k1/src/util.h @@ -14,6 +14,7 @@ #include <stdlib.h> #include <stdint.h> #include <stdio.h> +#include <limits.h> typedef struct { void (*fn)(const char *text, void* data); @@ -68,6 +69,25 @@ static SECP256K1_INLINE void secp256k1_callback_call(const secp256k1_callback * #define VERIFY_SETUP(stmt) #endif +/* Define `VG_UNDEF` and `VG_CHECK` when VALGRIND is defined */ +#if !defined(VG_CHECK) +# if defined(VALGRIND) +# include <valgrind/memcheck.h> +# define VG_UNDEF(x,y) VALGRIND_MAKE_MEM_UNDEFINED((x),(y)) +# define VG_CHECK(x,y) VALGRIND_CHECK_MEM_IS_DEFINED((x),(y)) +# else +# define VG_UNDEF(x,y) +# define VG_CHECK(x,y) +# endif +#endif + +/* Like `VG_CHECK` but on VERIFY only */ +#if defined(VERIFY) +#define VG_CHECK_VERIFY(x,y) VG_CHECK((x), (y)) +#else +#define VG_CHECK_VERIFY(x,y) +#endif + static SECP256K1_INLINE void *checked_malloc(const secp256k1_callback* cb, size_t size) { void *ret = malloc(size); if (ret == NULL) { @@ -84,6 +104,47 @@ static SECP256K1_INLINE void *checked_realloc(const secp256k1_callback* cb, void return ret; } +#if defined(__BIGGEST_ALIGNMENT__) +#define ALIGNMENT __BIGGEST_ALIGNMENT__ +#else +/* Using 16 bytes alignment because common architectures never have alignment + * requirements above 8 for any of the types we care about. In addition we + * leave some room because currently we don't care about a few bytes. */ +#define ALIGNMENT 16 +#endif + +#define ROUND_TO_ALIGN(size) (((size + ALIGNMENT - 1) / ALIGNMENT) * ALIGNMENT) + +/* Assume there is a contiguous memory object with bounds [base, base + max_size) + * of which the memory range [base, *prealloc_ptr) is already allocated for usage, + * where *prealloc_ptr is an aligned pointer. In that setting, this functions + * reserves the subobject [*prealloc_ptr, *prealloc_ptr + alloc_size) of + * alloc_size bytes by increasing *prealloc_ptr accordingly, taking into account + * alignment requirements. + * + * The function returns an aligned pointer to the newly allocated subobject. + * + * This is useful for manual memory management: if we're simply given a block + * [base, base + max_size), the caller can use this function to allocate memory + * in this block and keep track of the current allocation state with *prealloc_ptr. + * + * It is VERIFY_CHECKed that there is enough space left in the memory object and + * *prealloc_ptr is aligned relative to base. + */ +static SECP256K1_INLINE void *manual_alloc(void** prealloc_ptr, size_t alloc_size, void* base, size_t max_size) { + size_t aligned_alloc_size = ROUND_TO_ALIGN(alloc_size); + void* ret; + VERIFY_CHECK(prealloc_ptr != NULL); + VERIFY_CHECK(*prealloc_ptr != NULL); + VERIFY_CHECK(base != NULL); + VERIFY_CHECK((unsigned char*)*prealloc_ptr >= (unsigned char*)base); + VERIFY_CHECK(((unsigned char*)*prealloc_ptr - (unsigned char*)base) % ALIGNMENT == 0); + VERIFY_CHECK((unsigned char*)*prealloc_ptr - (unsigned char*)base + aligned_alloc_size <= max_size); + ret = *prealloc_ptr; + *((unsigned char**)prealloc_ptr) += aligned_alloc_size; + return ret; +} + /* Macro for restrict, when available and not in a VERIFY build. */ #if defined(SECP256K1_BUILD) && defined(VERIFY) # define SECP256K1_RESTRICT @@ -118,4 +179,33 @@ static SECP256K1_INLINE void *checked_realloc(const secp256k1_callback* cb, void SECP256K1_GNUC_EXT typedef unsigned __int128 uint128_t; #endif +/* Zero memory if flag == 1. Flag must be 0 or 1. Constant time. */ +static SECP256K1_INLINE void memczero(void *s, size_t len, int flag) { + unsigned char *p = (unsigned char *)s; + /* Access flag with a volatile-qualified lvalue. + This prevents clang from figuring out (after inlining) that flag can + take only be 0 or 1, which leads to variable time code. */ + volatile int vflag = flag; + unsigned char mask = -(unsigned char) vflag; + while (len) { + *p &= ~mask; + p++; + len--; + } +} + +/** If flag is true, set *r equal to *a; otherwise leave it. Constant-time. Both *r and *a must be initialized and non-negative.*/ +static SECP256K1_INLINE void secp256k1_int_cmov(int *r, const int *a, int flag) { + unsigned int mask0, mask1, r_masked, a_masked; + /* Casting a negative int to unsigned and back to int is implementation defined behavior */ + VERIFY_CHECK(*r >= 0 && *a >= 0); + + mask0 = (unsigned int)flag + ~0u; + mask1 = ~mask0; + r_masked = ((unsigned int)*r & mask0); + a_masked = ((unsigned int)*a & mask1); + + *r = (int)(r_masked | a_masked); +} + #endif /* SECP256K1_UTIL_H */ diff --git a/src/secp256k1/src/valgrind_ctime_test.c b/src/secp256k1/src/valgrind_ctime_test.c new file mode 100644 index 0000000000..60a82d599e --- /dev/null +++ b/src/secp256k1/src/valgrind_ctime_test.c @@ -0,0 +1,119 @@ +/********************************************************************** + * Copyright (c) 2020 Gregory Maxwell * + * Distributed under the MIT software license, see the accompanying * + * file COPYING or http://www.opensource.org/licenses/mit-license.php.* + **********************************************************************/ + +#include <valgrind/memcheck.h> +#include "include/secp256k1.h" +#include "util.h" + +#if ENABLE_MODULE_ECDH +# include "include/secp256k1_ecdh.h" +#endif + +#if ENABLE_MODULE_RECOVERY +# include "include/secp256k1_recovery.h" +#endif + +int main(void) { + secp256k1_context* ctx; + secp256k1_ecdsa_signature signature; + secp256k1_pubkey pubkey; + size_t siglen = 74; + size_t outputlen = 33; + int i; + int ret; + unsigned char msg[32]; + unsigned char key[32]; + unsigned char sig[74]; + unsigned char spubkey[33]; +#if ENABLE_MODULE_RECOVERY + secp256k1_ecdsa_recoverable_signature recoverable_signature; + int recid; +#endif + + if (!RUNNING_ON_VALGRIND) { + fprintf(stderr, "This test can only usefully be run inside valgrind.\n"); + fprintf(stderr, "Usage: libtool --mode=execute valgrind ./valgrind_ctime_test\n"); + exit(1); + } + + /** In theory, testing with a single secret input should be sufficient: + * If control flow depended on secrets the tool would generate an error. + */ + for (i = 0; i < 32; i++) { + key[i] = i + 65; + } + for (i = 0; i < 32; i++) { + msg[i] = i + 1; + } + + ctx = secp256k1_context_create(SECP256K1_CONTEXT_SIGN | SECP256K1_CONTEXT_DECLASSIFY); + + /* Test keygen. */ + VALGRIND_MAKE_MEM_UNDEFINED(key, 32); + ret = secp256k1_ec_pubkey_create(ctx, &pubkey, key); + VALGRIND_MAKE_MEM_DEFINED(&pubkey, sizeof(secp256k1_pubkey)); + VALGRIND_MAKE_MEM_DEFINED(&ret, sizeof(ret)); + CHECK(ret); + CHECK(secp256k1_ec_pubkey_serialize(ctx, spubkey, &outputlen, &pubkey, SECP256K1_EC_COMPRESSED) == 1); + + /* Test signing. */ + VALGRIND_MAKE_MEM_UNDEFINED(key, 32); + ret = secp256k1_ecdsa_sign(ctx, &signature, msg, key, NULL, NULL); + VALGRIND_MAKE_MEM_DEFINED(&signature, sizeof(secp256k1_ecdsa_signature)); + VALGRIND_MAKE_MEM_DEFINED(&ret, sizeof(ret)); + CHECK(ret); + CHECK(secp256k1_ecdsa_signature_serialize_der(ctx, sig, &siglen, &signature)); + +#if ENABLE_MODULE_ECDH + /* Test ECDH. */ + VALGRIND_MAKE_MEM_UNDEFINED(key, 32); + ret = secp256k1_ecdh(ctx, msg, &pubkey, key, NULL, NULL); + VALGRIND_MAKE_MEM_DEFINED(&ret, sizeof(ret)); + CHECK(ret == 1); +#endif + +#if ENABLE_MODULE_RECOVERY + /* Test signing a recoverable signature. */ + VALGRIND_MAKE_MEM_UNDEFINED(key, 32); + ret = secp256k1_ecdsa_sign_recoverable(ctx, &recoverable_signature, msg, key, NULL, NULL); + VALGRIND_MAKE_MEM_DEFINED(&recoverable_signature, sizeof(recoverable_signature)); + VALGRIND_MAKE_MEM_DEFINED(&ret, sizeof(ret)); + CHECK(ret); + CHECK(secp256k1_ecdsa_recoverable_signature_serialize_compact(ctx, sig, &recid, &recoverable_signature)); + CHECK(recid >= 0 && recid <= 3); +#endif + + VALGRIND_MAKE_MEM_UNDEFINED(key, 32); + ret = secp256k1_ec_seckey_verify(ctx, key); + VALGRIND_MAKE_MEM_DEFINED(&ret, sizeof(ret)); + CHECK(ret == 1); + + VALGRIND_MAKE_MEM_UNDEFINED(key, 32); + ret = secp256k1_ec_seckey_negate(ctx, key); + VALGRIND_MAKE_MEM_DEFINED(&ret, sizeof(ret)); + CHECK(ret == 1); + + VALGRIND_MAKE_MEM_UNDEFINED(key, 32); + VALGRIND_MAKE_MEM_UNDEFINED(msg, 32); + ret = secp256k1_ec_seckey_tweak_add(ctx, key, msg); + VALGRIND_MAKE_MEM_DEFINED(&ret, sizeof(ret)); + CHECK(ret == 1); + + VALGRIND_MAKE_MEM_UNDEFINED(key, 32); + VALGRIND_MAKE_MEM_UNDEFINED(msg, 32); + ret = secp256k1_ec_seckey_tweak_mul(ctx, key, msg); + VALGRIND_MAKE_MEM_DEFINED(&ret, sizeof(ret)); + CHECK(ret == 1); + + /* Test context randomisation. Do this last because it leaves the context tainted. */ + VALGRIND_MAKE_MEM_UNDEFINED(key, 32); + ret = secp256k1_context_randomize(ctx, key); + VALGRIND_MAKE_MEM_DEFINED(&ret, sizeof(ret)); + CHECK(ret); + + secp256k1_context_destroy(ctx); + return 0; +} diff --git a/src/serialize.h b/src/serialize.h index 71c2cfa164..7a94e704b2 100644 --- a/src/serialize.h +++ b/src/serialize.h @@ -9,13 +9,13 @@ #include <compat/endian.h> #include <algorithm> +#include <cstdint> #include <cstring> #include <ios> #include <limits> #include <map> #include <memory> #include <set> -#include <stdint.h> #include <string> #include <string.h> #include <utility> @@ -272,7 +272,7 @@ template<typename Stream> inline void Unserialize(Stream& s, bool& a) { char f=s inline unsigned int GetSizeOfCompactSize(uint64_t nSize) { if (nSize < 253) return sizeof(unsigned char); - else if (nSize <= std::numeric_limits<unsigned short>::max()) return sizeof(unsigned char) + sizeof(unsigned short); + else if (nSize <= std::numeric_limits<uint16_t>::max()) return sizeof(unsigned char) + sizeof(uint16_t); else if (nSize <= std::numeric_limits<unsigned int>::max()) return sizeof(unsigned char) + sizeof(unsigned int); else return sizeof(unsigned char) + sizeof(uint64_t); } @@ -286,7 +286,7 @@ void WriteCompactSize(Stream& os, uint64_t nSize) { ser_writedata8(os, nSize); } - else if (nSize <= std::numeric_limits<unsigned short>::max()) + else if (nSize <= std::numeric_limits<uint16_t>::max()) { ser_writedata8(os, 253); ser_writedata16(os, nSize); diff --git a/src/span.h b/src/span.h index 73b37e35f3..4afb383a59 100644 --- a/src/span.h +++ b/src/span.h @@ -10,20 +10,101 @@ #include <algorithm> #include <assert.h> +#ifdef DEBUG +#define CONSTEXPR_IF_NOT_DEBUG +#define ASSERT_IF_DEBUG(x) assert((x)) +#else +#define CONSTEXPR_IF_NOT_DEBUG constexpr +#define ASSERT_IF_DEBUG(x) +#endif + /** A Span is an object that can refer to a contiguous sequence of objects. * * It implements a subset of C++20's std::span. + * + * Things to be aware of when writing code that deals with Spans: + * + * - Similar to references themselves, Spans are subject to reference lifetime + * issues. The user is responsible for making sure the objects pointed to by + * a Span live as long as the Span is used. For example: + * + * std::vector<int> vec{1,2,3,4}; + * Span<int> sp(vec); + * vec.push_back(5); + * printf("%i\n", sp.front()); // UB! + * + * may exhibit undefined behavior, as increasing the size of a vector may + * invalidate references. + * + * - One particular pitfall is that Spans can be constructed from temporaries, + * but this is unsafe when the Span is stored in a variable, outliving the + * temporary. For example, this will compile, but exhibits undefined behavior: + * + * Span<const int> sp(std::vector<int>{1, 2, 3}); + * printf("%i\n", sp.front()); // UB! + * + * The lifetime of the vector ends when the statement it is created in ends. + * Thus the Span is left with a dangling reference, and using it is undefined. + * + * - Due to Span's automatic creation from range-like objects (arrays, and data + * types that expose a data() and size() member function), functions that + * accept a Span as input parameter can be called with any compatible + * range-like object. For example, this works: +* + * void Foo(Span<const int> arg); + * + * Foo(std::vector<int>{1, 2, 3}); // Works + * + * This is very useful in cases where a function truly does not care about the + * container, and only about having exactly a range of elements. However it + * may also be surprising to see automatic conversions in this case. + * + * When a function accepts a Span with a mutable element type, it will not + * accept temporaries; only variables or other references. For example: + * + * void FooMut(Span<int> arg); + * + * FooMut(std::vector<int>{1, 2, 3}); // Does not compile + * std::vector<int> baz{1, 2, 3}; + * FooMut(baz); // Works + * + * This is similar to how functions that take (non-const) lvalue references + * as input cannot accept temporaries. This does not work either: + * + * void FooVec(std::vector<int>& arg); + * FooVec(std::vector<int>{1, 2, 3}); // Does not compile + * + * The idea is that if a function accepts a mutable reference, a meaningful + * result will be present in that variable after the call. Passing a temporary + * is useless in that context. */ template<typename C> class Span { C* m_data; - std::ptrdiff_t m_size; + std::size_t m_size; public: constexpr Span() noexcept : m_data(nullptr), m_size(0) {} - constexpr Span(C* data, std::ptrdiff_t size) noexcept : m_data(data), m_size(size) {} - constexpr Span(C* data, C* end) noexcept : m_data(data), m_size(end - data) {} + + /** Construct a span from a begin pointer and a size. + * + * This implements a subset of the iterator-based std::span constructor in C++20, + * which is hard to implement without std::address_of. + */ + template <typename T, typename std::enable_if<std::is_convertible<T (*)[], C (*)[]>::value, int>::type = 0> + constexpr Span(T* begin, std::size_t size) noexcept : m_data(begin), m_size(size) {} + + /** Construct a span from a begin and end pointer. + * + * This implements a subset of the iterator-based std::span constructor in C++20, + * which is hard to implement without std::address_of. + */ + template <typename T, typename std::enable_if<std::is_convertible<T (*)[], C (*)[]>::value, int>::type = 0> + CONSTEXPR_IF_NOT_DEBUG Span(T* begin, T* end) noexcept : m_data(begin), m_size(end - begin) + { + ASSERT_IF_DEBUG(end >= begin); + } /** Implicit conversion of spans between compatible types. * @@ -42,18 +123,60 @@ public: /** Default assignment operator. */ Span& operator=(const Span& other) noexcept = default; + /** Construct a Span from an array. This matches the corresponding C++20 std::span constructor. */ + template <int N> + constexpr Span(C (&a)[N]) noexcept : m_data(a), m_size(N) {} + + /** Construct a Span for objects with .data() and .size() (std::string, std::array, std::vector, ...). + * + * This implements a subset of the functionality provided by the C++20 std::span range-based constructor. + * + * To prevent surprises, only Spans for constant value types are supported when passing in temporaries. + * Note that this restriction does not exist when converting arrays or other Spans (see above). + */ + template <typename V, typename std::enable_if<(std::is_const<C>::value || std::is_lvalue_reference<V>::value) && std::is_convertible<typename std::remove_pointer<decltype(std::declval<V&>().data())>::type (*)[], C (*)[]>::value && std::is_convertible<decltype(std::declval<V&>().size()), std::size_t>::value, int>::type = 0> + constexpr Span(V&& v) noexcept : m_data(v.data()), m_size(v.size()) {} + constexpr C* data() const noexcept { return m_data; } constexpr C* begin() const noexcept { return m_data; } constexpr C* end() const noexcept { return m_data + m_size; } - constexpr C& front() const noexcept { return m_data[0]; } - constexpr C& back() const noexcept { return m_data[m_size - 1]; } - constexpr std::ptrdiff_t size() const noexcept { return m_size; } - constexpr C& operator[](std::ptrdiff_t pos) const noexcept { return m_data[pos]; } - - constexpr Span<C> subspan(std::ptrdiff_t offset) const noexcept { return Span<C>(m_data + offset, m_size - offset); } - constexpr Span<C> subspan(std::ptrdiff_t offset, std::ptrdiff_t count) const noexcept { return Span<C>(m_data + offset, count); } - constexpr Span<C> first(std::ptrdiff_t count) const noexcept { return Span<C>(m_data, count); } - constexpr Span<C> last(std::ptrdiff_t count) const noexcept { return Span<C>(m_data + m_size - count, count); } + CONSTEXPR_IF_NOT_DEBUG C& front() const noexcept + { + ASSERT_IF_DEBUG(size() > 0); + return m_data[0]; + } + CONSTEXPR_IF_NOT_DEBUG C& back() const noexcept + { + ASSERT_IF_DEBUG(size() > 0); + return m_data[m_size - 1]; + } + constexpr std::size_t size() const noexcept { return m_size; } + constexpr bool empty() const noexcept { return size() == 0; } + CONSTEXPR_IF_NOT_DEBUG C& operator[](std::size_t pos) const noexcept + { + ASSERT_IF_DEBUG(size() > pos); + return m_data[pos]; + } + CONSTEXPR_IF_NOT_DEBUG Span<C> subspan(std::size_t offset) const noexcept + { + ASSERT_IF_DEBUG(size() >= offset); + return Span<C>(m_data + offset, m_size - offset); + } + CONSTEXPR_IF_NOT_DEBUG Span<C> subspan(std::size_t offset, std::size_t count) const noexcept + { + ASSERT_IF_DEBUG(size() >= offset + count); + return Span<C>(m_data + offset, count); + } + CONSTEXPR_IF_NOT_DEBUG Span<C> first(std::size_t count) const noexcept + { + ASSERT_IF_DEBUG(size() >= count); + return Span<C>(m_data, count); + } + CONSTEXPR_IF_NOT_DEBUG Span<C> last(std::size_t count) const noexcept + { + ASSERT_IF_DEBUG(size() >= count); + return Span<C>(m_data + m_size - count, count); + } friend constexpr bool operator==(const Span& a, const Span& b) noexcept { return a.size() == b.size() && std::equal(a.begin(), a.end(), b.begin()); } friend constexpr bool operator!=(const Span& a, const Span& b) noexcept { return !(a == b); } @@ -65,29 +188,35 @@ public: template <typename O> friend class Span; }; -/** Create a span to a container exposing data() and size(). - * - * This correctly deals with constness: the returned Span's element type will be - * whatever data() returns a pointer to. If either the passed container is const, - * or its element type is const, the resulting span will have a const element type. - * - * std::span will have a constructor that implements this functionality directly. - */ -template<typename A, int N> -constexpr Span<A> MakeSpan(A (&a)[N]) { return Span<A>(a, N); } - -template<typename V> -constexpr Span<typename std::remove_pointer<decltype(std::declval<V>().data())>::type> MakeSpan(V& v) { return Span<typename std::remove_pointer<decltype(std::declval<V>().data())>::type>(v.data(), v.size()); } +// MakeSpan helps constructing a Span of the right type automatically. +/** MakeSpan for arrays: */ +template <typename A, int N> Span<A> constexpr MakeSpan(A (&a)[N]) { return Span<A>(a, N); } +/** MakeSpan for temporaries / rvalue references, only supporting const output. */ +template <typename V> constexpr auto MakeSpan(V&& v) -> typename std::enable_if<!std::is_lvalue_reference<V>::value, Span<const typename std::remove_pointer<decltype(v.data())>::type>>::type { return std::forward<V>(v); } +/** MakeSpan for (lvalue) references, supporting mutable output. */ +template <typename V> constexpr auto MakeSpan(V& v) -> Span<typename std::remove_pointer<decltype(v.data())>::type> { return v; } /** Pop the last element off a span, and return a reference to that element. */ template <typename T> T& SpanPopBack(Span<T>& span) { size_t size = span.size(); - assert(size > 0); + ASSERT_IF_DEBUG(size > 0); T& back = span[size - 1]; span = Span<T>(span.data(), size - 1); return back; } +// Helper functions to safely cast to unsigned char pointers. +inline unsigned char* UCharCast(char* c) { return (unsigned char*)c; } +inline unsigned char* UCharCast(unsigned char* c) { return c; } +inline const unsigned char* UCharCast(const char* c) { return (unsigned char*)c; } +inline const unsigned char* UCharCast(const unsigned char* c) { return c; } + +// Helper function to safely convert a Span to a Span<[const] unsigned char>. +template <typename T> constexpr auto UCharSpanCast(Span<T> s) -> Span<typename std::remove_pointer<decltype(UCharCast(s.data()))>::type> { return {UCharCast(s.data()), s.size()}; } + +/** Like MakeSpan, but for (const) unsigned char member types only. Only works for (un)signed char containers. */ +template <typename V> constexpr auto MakeUCharSpan(V&& v) -> decltype(UCharSpanCast(MakeSpan(std::forward<V>(v)))) { return UCharSpanCast(MakeSpan(std::forward<V>(v))); } + #endif diff --git a/src/streams.h b/src/streams.h index e1d1b0eab2..6ce8065da8 100644 --- a/src/streams.h +++ b/src/streams.h @@ -814,18 +814,6 @@ public: return true; } - bool Seek(uint64_t nPos) { - long nLongPos = nPos; - if (nPos != (uint64_t)nLongPos) - return false; - if (fseek(src, nLongPos, SEEK_SET)) - return false; - nLongPos = ftell(src); - nSrcPos = nLongPos; - nReadPos = nLongPos; - return true; - } - //! prevent reading beyond a certain position //! no argument removes the limit bool SetLimit(uint64_t nPos = std::numeric_limits<uint64_t>::max()) { diff --git a/src/support/lockedpool.cpp b/src/support/lockedpool.cpp index f17b539e09..b4f392116c 100644 --- a/src/support/lockedpool.cpp +++ b/src/support/lockedpool.cpp @@ -29,7 +29,6 @@ #endif LockedPoolManager* LockedPoolManager::_instance = nullptr; -std::once_flag LockedPoolManager::init_flag; /*******************************************************************************/ // Utilities diff --git a/src/support/lockedpool.h b/src/support/lockedpool.h index de668f0773..b9e2e99d1a 100644 --- a/src/support/lockedpool.h +++ b/src/support/lockedpool.h @@ -221,7 +221,8 @@ public: /** Return the current instance, or create it once */ static LockedPoolManager& Instance() { - std::call_once(LockedPoolManager::init_flag, LockedPoolManager::CreateInstance); + static std::once_flag init_flag; + std::call_once(init_flag, LockedPoolManager::CreateInstance); return *LockedPoolManager::_instance; } @@ -234,7 +235,6 @@ private: static bool LockingFailed(); static LockedPoolManager* _instance; - static std::once_flag init_flag; }; #endif // BITCOIN_SUPPORT_LOCKEDPOOL_H diff --git a/src/sync.cpp b/src/sync.cpp index 9abdedbed4..4be13a3c48 100644 --- a/src/sync.cpp +++ b/src/sync.cpp @@ -60,7 +60,7 @@ struct CLockLocation { std::string ToString() const { return strprintf( - "%s %s:%s%s (in thread %s)", + "'%s' in %s:%s%s (in thread '%s')", mutexName, sourceFile, sourceLine, (fTry ? " (TRY)" : ""), m_thread_name); } @@ -105,7 +105,7 @@ static void potential_deadlock_detected(const LockPair& mismatch, const LockStac { LogPrintf("POTENTIAL DEADLOCK DETECTED\n"); LogPrintf("Previous lock order was:\n"); - for (const LockStackItem& i : s2) { + for (const LockStackItem& i : s1) { if (i.first == mismatch.first) { LogPrintf(" (1)"); /* Continued */ } @@ -114,21 +114,25 @@ static void potential_deadlock_detected(const LockPair& mismatch, const LockStac } LogPrintf(" %s\n", i.second.ToString()); } + + std::string mutex_a, mutex_b; LogPrintf("Current lock order is:\n"); - for (const LockStackItem& i : s1) { + for (const LockStackItem& i : s2) { if (i.first == mismatch.first) { LogPrintf(" (1)"); /* Continued */ + mutex_a = i.second.Name(); } if (i.first == mismatch.second) { LogPrintf(" (2)"); /* Continued */ + mutex_b = i.second.Name(); } LogPrintf(" %s\n", i.second.ToString()); } if (g_debug_lockorder_abort) { - tfm::format(std::cerr, "Assertion failed: detected inconsistent lock order at %s:%i, details in debug log.\n", __FILE__, __LINE__); + tfm::format(std::cerr, "Assertion failed: detected inconsistent lock order for %s, details in debug log.\n", s2.back().second.ToString()); abort(); } - throw std::logic_error("potential deadlock detected"); + throw std::logic_error(strprintf("potential deadlock detected: %s -> %s -> %s", mutex_b, mutex_a, mutex_b)); } static void push_lock(void* c, const CLockLocation& locklocation) @@ -145,12 +149,17 @@ static void push_lock(void* c, const CLockLocation& locklocation) const LockPair p1 = std::make_pair(i.first, c); if (lockdata.lockorders.count(p1)) continue; - lockdata.lockorders.emplace(p1, lock_stack); const LockPair p2 = std::make_pair(c, i.first); + if (lockdata.lockorders.count(p2)) { + auto lock_stack_copy = lock_stack; + lock_stack.pop_back(); + potential_deadlock_detected(p1, lockdata.lockorders[p2], lock_stack_copy); + // potential_deadlock_detected() does not return. + } + + lockdata.lockorders.emplace(p1, lock_stack); lockdata.invlockorders.insert(p2); - if (lockdata.lockorders.count(p2)) - potential_deadlock_detected(p1, lockdata.lockorders[p2], lockdata.lockorders[p1]); } } @@ -255,6 +264,17 @@ void DeleteLock(void* cs) } } +bool LockStackEmpty() +{ + LockData& lockdata = GetLockData(); + std::lock_guard<std::mutex> lock(lockdata.dd_mutex); + const auto it = lockdata.m_lock_stacks.find(std::this_thread::get_id()); + if (it == lockdata.m_lock_stacks.end()) { + return true; + } + return it->second.empty(); +} + bool g_debug_lockorder_abort = true; #endif /* DEBUG_LOCKORDER */ diff --git a/src/sync.h b/src/sync.h index 60e5a87aec..05ff2ee8a9 100644 --- a/src/sync.h +++ b/src/sync.h @@ -56,6 +56,7 @@ template <typename MutexType> void AssertLockHeldInternal(const char* pszName, const char* pszFile, int nLine, MutexType* cs) ASSERT_EXCLUSIVE_LOCK(cs); void AssertLockNotHeldInternal(const char* pszName, const char* pszFile, int nLine, void* cs); void DeleteLock(void* cs); +bool LockStackEmpty(); /** * Call abort() if a potential lock order deadlock bug is detected, instead of @@ -64,13 +65,14 @@ void DeleteLock(void* cs); */ extern bool g_debug_lockorder_abort; #else -void static inline EnterCritical(const char* pszName, const char* pszFile, int nLine, void* cs, bool fTry = false) {} -void static inline LeaveCritical() {} -void static inline CheckLastCritical(void* cs, std::string& lockname, const char* guardname, const char* file, int line) {} +inline void EnterCritical(const char* pszName, const char* pszFile, int nLine, void* cs, bool fTry = false) {} +inline void LeaveCritical() {} +inline void CheckLastCritical(void* cs, std::string& lockname, const char* guardname, const char* file, int line) {} template <typename MutexType> -void static inline AssertLockHeldInternal(const char* pszName, const char* pszFile, int nLine, MutexType* cs) ASSERT_EXCLUSIVE_LOCK(cs) {} -void static inline AssertLockNotHeldInternal(const char* pszName, const char* pszFile, int nLine, void* cs) {} -void static inline DeleteLock(void* cs) {} +inline void AssertLockHeldInternal(const char* pszName, const char* pszFile, int nLine, MutexType* cs) ASSERT_EXCLUSIVE_LOCK(cs) {} +inline void AssertLockNotHeldInternal(const char* pszName, const char* pszFile, int nLine, void* cs) {} +inline void DeleteLock(void* cs) {} +inline bool LockStackEmpty() { return true; } #endif #define AssertLockHeld(cs) AssertLockHeldInternal(#cs, __FILE__, __LINE__, &cs) #define AssertLockNotHeld(cs) AssertLockNotHeldInternal(#cs, __FILE__, __LINE__, &cs) @@ -103,6 +105,12 @@ public: } using UniqueLock = std::unique_lock<PARENT>; +#ifdef __clang__ + //! For negative capabilities in the Clang Thread Safety Analysis. + //! A negative requirement uses the EXCLUSIVE_LOCKS_REQUIRED attribute, in conjunction + //! with the ! operator, to indicate that a mutex should not be held. + const AnnotatedMixin& operator!() const { return *this; } +#endif // __clang__ }; /** diff --git a/src/test/addrman_tests.cpp b/src/test/addrman_tests.cpp index bc6b38c682..25fdd64568 100644 --- a/src/test/addrman_tests.cpp +++ b/src/test/addrman_tests.cpp @@ -392,7 +392,7 @@ BOOST_AUTO_TEST_CASE(addrman_getaddr) // Test: Sanity check, GetAddr should never return anything if addrman // is empty. BOOST_CHECK_EQUAL(addrman.size(), 0U); - std::vector<CAddress> vAddr1 = addrman.GetAddr(); + std::vector<CAddress> vAddr1 = addrman.GetAddr(/* max_addresses */ 0, /* max_pct */0); BOOST_CHECK_EQUAL(vAddr1.size(), 0U); CAddress addr1 = CAddress(ResolveService("250.250.2.1", 8333), NODE_NONE); @@ -415,13 +415,15 @@ BOOST_AUTO_TEST_CASE(addrman_getaddr) BOOST_CHECK(addrman.Add(addr4, source2)); BOOST_CHECK(addrman.Add(addr5, source1)); - // GetAddr returns 23% of addresses, 23% of 5 is 1 rounded down. - BOOST_CHECK_EQUAL(addrman.GetAddr().size(), 1U); + BOOST_CHECK_EQUAL(addrman.GetAddr(/* max_addresses */ 0, /* max_pct */ 0).size(), 5U); + // Net processing asks for 23% of addresses. 23% of 5 is 1 rounded down. + BOOST_CHECK_EQUAL(addrman.GetAddr(/* max_addresses */ 2500, /* max_pct */ 23).size(), 1U); // Test: Ensure GetAddr works with new and tried addresses. addrman.Good(CAddress(addr1, NODE_NONE)); addrman.Good(CAddress(addr2, NODE_NONE)); - BOOST_CHECK_EQUAL(addrman.GetAddr().size(), 1U); + BOOST_CHECK_EQUAL(addrman.GetAddr(/* max_addresses */ 0, /* max_pct */ 0).size(), 5U); + BOOST_CHECK_EQUAL(addrman.GetAddr(/* max_addresses */ 2500, /* max_pct */ 23).size(), 1U); // Test: Ensure GetAddr still returns 23% when addrman has many addrs. for (unsigned int i = 1; i < (8 * 256); i++) { @@ -436,7 +438,7 @@ BOOST_AUTO_TEST_CASE(addrman_getaddr) if (i % 8 == 0) addrman.Good(addr); } - std::vector<CAddress> vAddr = addrman.GetAddr(); + std::vector<CAddress> vAddr = addrman.GetAddr(/* max_addresses */ 2500, /* max_pct */ 23); size_t percent23 = (addrman.size() * 23) / 100; BOOST_CHECK_EQUAL(vAddr.size(), percent23); diff --git a/src/test/blockfilter_index_tests.cpp b/src/test/blockfilter_index_tests.cpp index 7dff2e6e86..00c4bdc14e 100644 --- a/src/test/blockfilter_index_tests.cpp +++ b/src/test/blockfilter_index_tests.cpp @@ -94,7 +94,7 @@ bool BuildChainTestingSetup::BuildChain(const CBlockIndex* pindex, CBlockHeader header = block->GetBlockHeader(); BlockValidationState state; - if (!EnsureChainman(m_node).ProcessNewBlockHeaders({header}, state, Params(), &pindex)) { + if (!Assert(m_node.chainman)->ProcessNewBlockHeaders({header}, state, Params(), &pindex)) { return false; } } @@ -171,7 +171,7 @@ BOOST_FIXTURE_TEST_CASE(blockfilter_index_initial_sync, BuildChainTestingSetup) uint256 chainA_last_header = last_header; for (size_t i = 0; i < 2; i++) { const auto& block = chainA[i]; - BOOST_REQUIRE(EnsureChainman(m_node).ProcessNewBlock(Params(), block, true, nullptr)); + BOOST_REQUIRE(Assert(m_node.chainman)->ProcessNewBlock(Params(), block, true, nullptr)); } for (size_t i = 0; i < 2; i++) { const auto& block = chainA[i]; @@ -189,7 +189,7 @@ BOOST_FIXTURE_TEST_CASE(blockfilter_index_initial_sync, BuildChainTestingSetup) uint256 chainB_last_header = last_header; for (size_t i = 0; i < 3; i++) { const auto& block = chainB[i]; - BOOST_REQUIRE(EnsureChainman(m_node).ProcessNewBlock(Params(), block, true, nullptr)); + BOOST_REQUIRE(Assert(m_node.chainman)->ProcessNewBlock(Params(), block, true, nullptr)); } for (size_t i = 0; i < 3; i++) { const auto& block = chainB[i]; @@ -220,7 +220,7 @@ BOOST_FIXTURE_TEST_CASE(blockfilter_index_initial_sync, BuildChainTestingSetup) // Reorg back to chain A. for (size_t i = 2; i < 4; i++) { const auto& block = chainA[i]; - BOOST_REQUIRE(EnsureChainman(m_node).ProcessNewBlock(Params(), block, true, nullptr)); + BOOST_REQUIRE(Assert(m_node.chainman)->ProcessNewBlock(Params(), block, true, nullptr)); } // Check that chain A and B blocks can be retrieved. diff --git a/src/test/checkqueue_tests.cpp b/src/test/checkqueue_tests.cpp index 35750b2ebc..8348810ac1 100644 --- a/src/test/checkqueue_tests.cpp +++ b/src/test/checkqueue_tests.cpp @@ -10,7 +10,7 @@ #include <util/time.h> #include <boost/test/unit_test.hpp> -#include <boost/thread.hpp> +#include <boost/thread/thread.hpp> #include <atomic> #include <condition_variable> diff --git a/src/test/coins_tests.cpp b/src/test/coins_tests.cpp index 60196c36a5..173ec5e3d9 100644 --- a/src/test/coins_tests.cpp +++ b/src/test/coins_tests.cpp @@ -538,7 +538,7 @@ BOOST_AUTO_TEST_CASE(ccoins_serialization) CDataStream tmp(SER_DISK, CLIENT_VERSION); uint64_t x = 3000000000ULL; tmp << VARINT(x); - BOOST_CHECK_EQUAL(HexStr(tmp.begin(), tmp.end()), "8a95c0bb00"); + BOOST_CHECK_EQUAL(HexStr(tmp), "8a95c0bb00"); CDataStream ss5(ParseHex("00008a95c0bb00"), SER_DISK, CLIENT_VERSION); try { Coin cc5; diff --git a/src/test/crypto_tests.cpp b/src/test/crypto_tests.cpp index f64251fe32..b3cc8cefd9 100644 --- a/src/test/crypto_tests.cpp +++ b/src/test/crypto_tests.cpp @@ -183,7 +183,7 @@ static void TestHKDF_SHA256_32(const std::string &ikm_hex, const std::string &sa CHKDF_HMAC_SHA256_L32 hkdf32(initial_key_material.data(), initial_key_material.size(), salt_stringified); unsigned char out[32]; hkdf32.Expand32(info_stringified, out); - BOOST_CHECK(HexStr(out, out + 32) == okm_check_hex); + BOOST_CHECK(HexStr(out) == okm_check_hex); } static std::string LongTestString() @@ -743,7 +743,7 @@ BOOST_AUTO_TEST_CASE(sha256d64) in[j] = InsecureRandBits(8); } for (int j = 0; j < i; ++j) { - CHash256().Write(in + 64 * j, 64).Finalize(out1 + 32 * j); + CHash256().Write({in + 64 * j, 64}).Finalize({out1 + 32 * j, 32}); } SHA256D64(out2, in, i); BOOST_CHECK(memcmp(out1, out2, 32 * i) == 0); diff --git a/src/test/data/script_tests.json b/src/test/data/script_tests.json index c01ef307b7..724789bbf9 100644 --- a/src/test/data/script_tests.json +++ b/src/test/data/script_tests.json @@ -678,7 +678,7 @@ ["0 0x02 0x0000 0", "CHECKMULTISIGVERIFY 1", "", "OK"], ["While not really correctly DER encoded, the empty signature is allowed by"], -["STRICTENC to provide a compact way to provide a delibrately invalid signature."], +["STRICTENC to provide a compact way to provide a deliberately invalid signature."], ["0", "0x21 0x02865c40293a680cb9c020e7b1e106d8c1916d3cef99aa431a56d253e69256dac0 CHECKSIG NOT", "STRICTENC", "OK"], ["0 0", "1 0x21 0x02865c40293a680cb9c020e7b1e106d8c1916d3cef99aa431a56d253e69256dac0 1 CHECKMULTISIG NOT", "STRICTENC", "OK"], diff --git a/src/test/denialofservice_tests.cpp b/src/test/denialofservice_tests.cpp index 348b170536..0115803e58 100644 --- a/src/test/denialofservice_tests.cpp +++ b/src/test/denialofservice_tests.cpp @@ -4,10 +4,12 @@ // Unit tests for denial-of-service detection/prevention code +#include <arith_uint256.h> #include <banman.h> #include <chainparams.h> #include <net.h> #include <net_processing.h> +#include <pubkey.h> #include <script/sign.h> #include <script/signingprovider.h> #include <script/standard.h> @@ -82,7 +84,7 @@ BOOST_AUTO_TEST_CASE(outbound_slow_chain_eviction) // Mock an outbound peer CAddress addr1(ip(0xa0b0c001), NODE_NONE); - CNode dummyNode1(id++, ServiceFlags(NODE_NETWORK|NODE_WITNESS), 0, INVALID_SOCKET, addr1, 0, 0, CAddress(), "", /*fInboundIn=*/ false); + CNode dummyNode1(id++, ServiceFlags(NODE_NETWORK|NODE_WITNESS), 0, INVALID_SOCKET, addr1, 0, 0, CAddress(), "", ConnectionType::OUTBOUND); dummyNode1.SetSendVersion(PROTOCOL_VERSION); peerLogic->InitializeNode(&dummyNode1); @@ -98,11 +100,11 @@ BOOST_AUTO_TEST_CASE(outbound_slow_chain_eviction) // Test starts here { - LOCK2(cs_main, dummyNode1.cs_sendProcessing); + LOCK(dummyNode1.cs_sendProcessing); BOOST_CHECK(peerLogic->SendMessages(&dummyNode1)); // should result in getheaders } { - LOCK2(cs_main, dummyNode1.cs_vSend); + LOCK(dummyNode1.cs_vSend); BOOST_CHECK(dummyNode1.vSendMsg.size() > 0); dummyNode1.vSendMsg.clear(); } @@ -111,17 +113,17 @@ BOOST_AUTO_TEST_CASE(outbound_slow_chain_eviction) // Wait 21 minutes SetMockTime(nStartTime+21*60); { - LOCK2(cs_main, dummyNode1.cs_sendProcessing); + LOCK(dummyNode1.cs_sendProcessing); BOOST_CHECK(peerLogic->SendMessages(&dummyNode1)); // should result in getheaders } { - LOCK2(cs_main, dummyNode1.cs_vSend); + LOCK(dummyNode1.cs_vSend); BOOST_CHECK(dummyNode1.vSendMsg.size() > 0); } // Wait 3 more minutes SetMockTime(nStartTime+24*60); { - LOCK2(cs_main, dummyNode1.cs_sendProcessing); + LOCK(dummyNode1.cs_sendProcessing); BOOST_CHECK(peerLogic->SendMessages(&dummyNode1)); // should result in disconnect } BOOST_CHECK(dummyNode1.fDisconnect == true); @@ -134,7 +136,7 @@ BOOST_AUTO_TEST_CASE(outbound_slow_chain_eviction) static void AddRandomOutboundPeer(std::vector<CNode *> &vNodes, PeerLogicValidation &peerLogic, CConnmanTest* connman) { CAddress addr(ip(g_insecure_rand_ctx.randbits(32)), NODE_NONE); - vNodes.emplace_back(new CNode(id++, ServiceFlags(NODE_NETWORK|NODE_WITNESS), 0, INVALID_SOCKET, addr, 0, 0, CAddress(), "", /*fInboundIn=*/ false)); + vNodes.emplace_back(new CNode(id++, ServiceFlags(NODE_NETWORK|NODE_WITNESS), 0, INVALID_SOCKET, addr, 0, 0, CAddress(), "", ConnectionType::OUTBOUND)); CNode &node = *vNodes.back(); node.SetSendVersion(PROTOCOL_VERSION); @@ -217,7 +219,7 @@ BOOST_AUTO_TEST_CASE(stale_tip_peer_management) connman->ClearNodes(); } -BOOST_AUTO_TEST_CASE(DoS_banning) +BOOST_AUTO_TEST_CASE(peer_discouragement) { auto banman = MakeUnique<BanMan>(GetDataDir() / "banlist.dat", nullptr, DEFAULT_MISBEHAVING_BANTIME); auto connman = MakeUnique<CConnman>(0x1337, 0x1337); @@ -225,100 +227,54 @@ BOOST_AUTO_TEST_CASE(DoS_banning) banman->ClearBanned(); CAddress addr1(ip(0xa0b0c001), NODE_NONE); - CNode dummyNode1(id++, NODE_NETWORK, 0, INVALID_SOCKET, addr1, 0, 0, CAddress(), "", true); + CNode dummyNode1(id++, NODE_NETWORK, 0, INVALID_SOCKET, addr1, 0, 0, CAddress(), "", ConnectionType::INBOUND); dummyNode1.SetSendVersion(PROTOCOL_VERSION); peerLogic->InitializeNode(&dummyNode1); dummyNode1.nVersion = 1; dummyNode1.fSuccessfullyConnected = true; { LOCK(cs_main); - Misbehaving(dummyNode1.GetId(), 100); // Should get banned + Misbehaving(dummyNode1.GetId(), DISCOURAGEMENT_THRESHOLD); // Should be discouraged } { - LOCK2(cs_main, dummyNode1.cs_sendProcessing); + LOCK(dummyNode1.cs_sendProcessing); BOOST_CHECK(peerLogic->SendMessages(&dummyNode1)); } - BOOST_CHECK(banman->IsBanned(addr1)); - BOOST_CHECK(!banman->IsBanned(ip(0xa0b0c001|0x0000ff00))); // Different IP, not banned + BOOST_CHECK(banman->IsDiscouraged(addr1)); + BOOST_CHECK(!banman->IsDiscouraged(ip(0xa0b0c001|0x0000ff00))); // Different IP, not discouraged CAddress addr2(ip(0xa0b0c002), NODE_NONE); - CNode dummyNode2(id++, NODE_NETWORK, 0, INVALID_SOCKET, addr2, 1, 1, CAddress(), "", true); + CNode dummyNode2(id++, NODE_NETWORK, 0, INVALID_SOCKET, addr2, 1, 1, CAddress(), "", ConnectionType::INBOUND); dummyNode2.SetSendVersion(PROTOCOL_VERSION); peerLogic->InitializeNode(&dummyNode2); dummyNode2.nVersion = 1; dummyNode2.fSuccessfullyConnected = true; { LOCK(cs_main); - Misbehaving(dummyNode2.GetId(), 50); + Misbehaving(dummyNode2.GetId(), DISCOURAGEMENT_THRESHOLD - 1); } { - LOCK2(cs_main, dummyNode2.cs_sendProcessing); + LOCK(dummyNode2.cs_sendProcessing); BOOST_CHECK(peerLogic->SendMessages(&dummyNode2)); } - BOOST_CHECK(!banman->IsBanned(addr2)); // 2 not banned yet... - BOOST_CHECK(banman->IsBanned(addr1)); // ... but 1 still should be + BOOST_CHECK(!banman->IsDiscouraged(addr2)); // 2 not discouraged yet... + BOOST_CHECK(banman->IsDiscouraged(addr1)); // ... but 1 still should be { LOCK(cs_main); - Misbehaving(dummyNode2.GetId(), 50); + Misbehaving(dummyNode2.GetId(), 1); // 2 reaches discouragement threshold } { - LOCK2(cs_main, dummyNode2.cs_sendProcessing); + LOCK(dummyNode2.cs_sendProcessing); BOOST_CHECK(peerLogic->SendMessages(&dummyNode2)); } - BOOST_CHECK(banman->IsBanned(addr2)); + BOOST_CHECK(banman->IsDiscouraged(addr1)); // Expect both 1 and 2 + BOOST_CHECK(banman->IsDiscouraged(addr2)); // to be discouraged now bool dummy; peerLogic->FinalizeNode(dummyNode1.GetId(), dummy); peerLogic->FinalizeNode(dummyNode2.GetId(), dummy); } -BOOST_AUTO_TEST_CASE(DoS_banscore) -{ - auto banman = MakeUnique<BanMan>(GetDataDir() / "banlist.dat", nullptr, DEFAULT_MISBEHAVING_BANTIME); - auto connman = MakeUnique<CConnman>(0x1337, 0x1337); - auto peerLogic = MakeUnique<PeerLogicValidation>(connman.get(), banman.get(), *m_node.scheduler, *m_node.chainman, *m_node.mempool); - - banman->ClearBanned(); - gArgs.ForceSetArg("-banscore", "111"); // because 11 is my favorite number - CAddress addr1(ip(0xa0b0c001), NODE_NONE); - CNode dummyNode1(id++, NODE_NETWORK, 0, INVALID_SOCKET, addr1, 3, 1, CAddress(), "", true); - dummyNode1.SetSendVersion(PROTOCOL_VERSION); - peerLogic->InitializeNode(&dummyNode1); - dummyNode1.nVersion = 1; - dummyNode1.fSuccessfullyConnected = true; - { - LOCK(cs_main); - Misbehaving(dummyNode1.GetId(), 100); - } - { - LOCK2(cs_main, dummyNode1.cs_sendProcessing); - BOOST_CHECK(peerLogic->SendMessages(&dummyNode1)); - } - BOOST_CHECK(!banman->IsBanned(addr1)); - { - LOCK(cs_main); - Misbehaving(dummyNode1.GetId(), 10); - } - { - LOCK2(cs_main, dummyNode1.cs_sendProcessing); - BOOST_CHECK(peerLogic->SendMessages(&dummyNode1)); - } - BOOST_CHECK(!banman->IsBanned(addr1)); - { - LOCK(cs_main); - Misbehaving(dummyNode1.GetId(), 1); - } - { - LOCK2(cs_main, dummyNode1.cs_sendProcessing); - BOOST_CHECK(peerLogic->SendMessages(&dummyNode1)); - } - BOOST_CHECK(banman->IsBanned(addr1)); - gArgs.ForceSetArg("-banscore", ToString(DEFAULT_BANSCORE_THRESHOLD)); - - bool dummy; - peerLogic->FinalizeNode(dummyNode1.GetId(), dummy); -} - BOOST_AUTO_TEST_CASE(DoS_bantime) { auto banman = MakeUnique<BanMan>(GetDataDir() / "banlist.dat", nullptr, DEFAULT_MISBEHAVING_BANTIME); @@ -330,7 +286,7 @@ BOOST_AUTO_TEST_CASE(DoS_bantime) SetMockTime(nStartTime); // Overrides future calls to GetTime() CAddress addr(ip(0xa0b0c001), NODE_NONE); - CNode dummyNode(id++, NODE_NETWORK, 0, INVALID_SOCKET, addr, 4, 4, CAddress(), "", true); + CNode dummyNode(id++, NODE_NETWORK, 0, INVALID_SOCKET, addr, 4, 4, CAddress(), "", ConnectionType::INBOUND); dummyNode.SetSendVersion(PROTOCOL_VERSION); peerLogic->InitializeNode(&dummyNode); dummyNode.nVersion = 1; @@ -338,19 +294,13 @@ BOOST_AUTO_TEST_CASE(DoS_bantime) { LOCK(cs_main); - Misbehaving(dummyNode.GetId(), 100); + Misbehaving(dummyNode.GetId(), DISCOURAGEMENT_THRESHOLD); } { - LOCK2(cs_main, dummyNode.cs_sendProcessing); + LOCK(dummyNode.cs_sendProcessing); BOOST_CHECK(peerLogic->SendMessages(&dummyNode)); } - BOOST_CHECK(banman->IsBanned(addr)); - - SetMockTime(nStartTime+60*60); - BOOST_CHECK(banman->IsBanned(addr)); - - SetMockTime(nStartTime+60*60*24+1); - BOOST_CHECK(!banman->IsBanned(addr)); + BOOST_CHECK(banman->IsDiscouraged(addr)); bool dummy; peerLogic->FinalizeNode(dummyNode.GetId(), dummy); @@ -366,10 +316,26 @@ static CTransactionRef RandomOrphan() return it->second.tx; } +static void MakeNewKeyWithFastRandomContext(CKey& key) +{ + std::vector<unsigned char> keydata; + keydata = g_insecure_rand_ctx.randbytes(32); + key.Set(keydata.data(), keydata.data() + keydata.size(), /*fCompressedIn*/ true); + assert(key.IsValid()); +} + BOOST_AUTO_TEST_CASE(DoS_mapOrphans) { + // This test had non-deterministic coverage due to + // randomly selected seeds. + // This seed is chosen so that all branches of the function + // ecdsa_signature_parse_der_lax are executed during this test. + // Specifically branches that run only when an ECDSA + // signature's R and S values have leading zeros. + g_insecure_rand_ctx = FastRandomContext(ArithToUint256(arith_uint256(33))); + CKey key; - key.MakeNewKey(true); + MakeNewKeyWithFastRandomContext(key); FillableSigningProvider keystore; BOOST_CHECK(keystore.AddKey(key)); diff --git a/src/test/descriptor_tests.cpp b/src/test/descriptor_tests.cpp index 5f9a78ceb2..20132d5782 100644 --- a/src/test/descriptor_tests.cpp +++ b/src/test/descriptor_tests.cpp @@ -135,7 +135,7 @@ void DoCheck(const std::string& prv, const std::string& pub, int flags, const st // When the descriptor is hardened, evaluate with access to the private keys inside. const FlatSigningProvider& key_provider = (flags & HARDENED) ? keys_priv : keys_pub; - // Evaluate the descriptor selected by `t` in poisition `i`. + // Evaluate the descriptor selected by `t` in position `i`. FlatSigningProvider script_provider, script_provider_cached; std::vector<CScript> spks, spks_cached; DescriptorCache desc_cache; @@ -216,7 +216,7 @@ void DoCheck(const std::string& prv, const std::string& pub, int flags, const st // For each of the produced scripts, verify solvability, and when possible, try to sign a transaction spending it. for (size_t n = 0; n < spks.size(); ++n) { - BOOST_CHECK_EQUAL(ref[n], HexStr(spks[n].begin(), spks[n].end())); + BOOST_CHECK_EQUAL(ref[n], HexStr(spks[n])); BOOST_CHECK_EQUAL(IsSolvable(Merge(key_provider, script_provider), spks[n]), (flags & UNSOLVABLE) == 0); if (flags & SIGNABLE) { diff --git a/src/test/fuzz/addrdb.cpp b/src/test/fuzz/addrdb.cpp index 524cea83fe..16b1cb755a 100644 --- a/src/test/fuzz/addrdb.cpp +++ b/src/test/fuzz/addrdb.cpp @@ -17,19 +17,13 @@ void test_one_input(const std::vector<uint8_t>& buffer) { FuzzedDataProvider fuzzed_data_provider(buffer.data(), buffer.size()); + // The point of this code is to exercise all CBanEntry constructors. const CBanEntry ban_entry = [&] { - switch (fuzzed_data_provider.ConsumeIntegralInRange<int>(0, 3)) { + switch (fuzzed_data_provider.ConsumeIntegralInRange<int>(0, 2)) { case 0: return CBanEntry{fuzzed_data_provider.ConsumeIntegral<int64_t>()}; break; - case 1: - return CBanEntry{fuzzed_data_provider.ConsumeIntegral<int64_t>(), fuzzed_data_provider.PickValueInArray<BanReason>({ - BanReason::BanReasonUnknown, - BanReason::BanReasonNodeMisbehaving, - BanReason::BanReasonManuallyAdded, - })}; - break; - case 2: { + case 1: { const std::optional<CBanEntry> ban_entry = ConsumeDeserializable<CBanEntry>(fuzzed_data_provider); if (ban_entry) { return *ban_entry; @@ -39,5 +33,5 @@ void test_one_input(const std::vector<uint8_t>& buffer) } return CBanEntry{}; }(); - assert(!ban_entry.banReasonToString().empty()); + (void)ban_entry; // currently unused } diff --git a/src/test/fuzz/autofile.cpp b/src/test/fuzz/autofile.cpp new file mode 100644 index 0000000000..7ea0bdd2a7 --- /dev/null +++ b/src/test/fuzz/autofile.cpp @@ -0,0 +1,72 @@ +// Copyright (c) 2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include <optional.h> +#include <streams.h> +#include <test/fuzz/FuzzedDataProvider.h> +#include <test/fuzz/fuzz.h> +#include <test/fuzz/util.h> + +#include <array> +#include <cstdint> +#include <iostream> +#include <optional> +#include <string> +#include <vector> + +void test_one_input(const std::vector<uint8_t>& buffer) +{ + FuzzedDataProvider fuzzed_data_provider{buffer.data(), buffer.size()}; + FuzzedAutoFileProvider fuzzed_auto_file_provider = ConsumeAutoFile(fuzzed_data_provider); + CAutoFile auto_file = fuzzed_auto_file_provider.open(); + while (fuzzed_data_provider.ConsumeBool()) { + switch (fuzzed_data_provider.ConsumeIntegralInRange<int>(0, 5)) { + case 0: { + std::array<uint8_t, 4096> arr{}; + try { + auto_file.read((char*)arr.data(), fuzzed_data_provider.ConsumeIntegralInRange<size_t>(0, 4096)); + } catch (const std::ios_base::failure&) { + } + break; + } + case 1: { + const std::array<uint8_t, 4096> arr{}; + try { + auto_file.write((const char*)arr.data(), fuzzed_data_provider.ConsumeIntegralInRange<size_t>(0, 4096)); + } catch (const std::ios_base::failure&) { + } + break; + } + case 2: { + try { + auto_file.ignore(fuzzed_data_provider.ConsumeIntegralInRange<size_t>(0, 4096)); + } catch (const std::ios_base::failure&) { + } + break; + } + case 3: { + auto_file.fclose(); + break; + } + case 4: { + ReadFromStream(fuzzed_data_provider, auto_file); + break; + } + case 5: { + WriteToStream(fuzzed_data_provider, auto_file); + break; + } + } + } + (void)auto_file.Get(); + (void)auto_file.GetType(); + (void)auto_file.GetVersion(); + (void)auto_file.IsNull(); + if (fuzzed_data_provider.ConsumeBool()) { + FILE* f = auto_file.release(); + if (f != nullptr) { + fclose(f); + } + } +} diff --git a/src/test/fuzz/banman.cpp b/src/test/fuzz/banman.cpp new file mode 100644 index 0000000000..fc4a1d9261 --- /dev/null +++ b/src/test/fuzz/banman.cpp @@ -0,0 +1,88 @@ +// Copyright (c) 2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include <banman.h> +#include <fs.h> +#include <netaddress.h> +#include <test/fuzz/FuzzedDataProvider.h> +#include <test/fuzz/fuzz.h> +#include <test/fuzz/util.h> +#include <util/system.h> + +#include <cstdint> +#include <limits> +#include <string> +#include <vector> + +namespace { +int64_t ConsumeBanTimeOffset(FuzzedDataProvider& fuzzed_data_provider) noexcept +{ + // Avoid signed integer overflow by capping to int32_t max: + // banman.cpp:137:73: runtime error: signed integer overflow: 1591700817 + 9223372036854775807 cannot be represented in type 'long' + return fuzzed_data_provider.ConsumeIntegralInRange<int64_t>(std::numeric_limits<int64_t>::min(), std::numeric_limits<int32_t>::max()); +} +} // namespace + +void initialize() +{ + InitializeFuzzingContext(); +} + +void test_one_input(const std::vector<uint8_t>& buffer) +{ + FuzzedDataProvider fuzzed_data_provider{buffer.data(), buffer.size()}; + const fs::path banlist_file = GetDataDir() / "fuzzed_banlist.dat"; + fs::remove(banlist_file); + { + BanMan ban_man{banlist_file, nullptr, ConsumeBanTimeOffset(fuzzed_data_provider)}; + while (fuzzed_data_provider.ConsumeBool()) { + switch (fuzzed_data_provider.ConsumeIntegralInRange<int>(0, 11)) { + case 0: { + ban_man.Ban(ConsumeNetAddr(fuzzed_data_provider), + ConsumeBanTimeOffset(fuzzed_data_provider), fuzzed_data_provider.ConsumeBool()); + break; + } + case 1: { + ban_man.Ban(ConsumeSubNet(fuzzed_data_provider), + ConsumeBanTimeOffset(fuzzed_data_provider), fuzzed_data_provider.ConsumeBool()); + break; + } + case 2: { + ban_man.ClearBanned(); + break; + } + case 4: { + ban_man.IsBanned(ConsumeNetAddr(fuzzed_data_provider)); + break; + } + case 5: { + ban_man.IsBanned(ConsumeSubNet(fuzzed_data_provider)); + break; + } + case 6: { + ban_man.Unban(ConsumeNetAddr(fuzzed_data_provider)); + break; + } + case 7: { + ban_man.Unban(ConsumeSubNet(fuzzed_data_provider)); + break; + } + case 8: { + banmap_t banmap; + ban_man.GetBanned(banmap); + break; + } + case 9: { + ban_man.DumpBanlist(); + break; + } + case 11: { + ban_man.Discourage(ConsumeNetAddr(fuzzed_data_provider)); + break; + } + } + } + } + fs::remove(banlist_file); +} diff --git a/src/test/fuzz/buffered_file.cpp b/src/test/fuzz/buffered_file.cpp new file mode 100644 index 0000000000..e575640be5 --- /dev/null +++ b/src/test/fuzz/buffered_file.cpp @@ -0,0 +1,74 @@ +// Copyright (c) 2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include <optional.h> +#include <streams.h> +#include <test/fuzz/FuzzedDataProvider.h> +#include <test/fuzz/fuzz.h> +#include <test/fuzz/util.h> + +#include <array> +#include <cstdint> +#include <iostream> +#include <optional> +#include <string> +#include <vector> + +void test_one_input(const std::vector<uint8_t>& buffer) +{ + FuzzedDataProvider fuzzed_data_provider{buffer.data(), buffer.size()}; + FuzzedFileProvider fuzzed_file_provider = ConsumeFile(fuzzed_data_provider); + std::optional<CBufferedFile> opt_buffered_file; + FILE* fuzzed_file = fuzzed_file_provider.open(); + try { + opt_buffered_file.emplace(fuzzed_file, fuzzed_data_provider.ConsumeIntegralInRange<uint64_t>(0, 4096), fuzzed_data_provider.ConsumeIntegralInRange<uint64_t>(0, 4096), fuzzed_data_provider.ConsumeIntegral<int>(), fuzzed_data_provider.ConsumeIntegral<int>()); + } catch (const std::ios_base::failure&) { + if (fuzzed_file != nullptr) { + fclose(fuzzed_file); + } + } + if (opt_buffered_file && fuzzed_file != nullptr) { + bool setpos_fail = false; + while (fuzzed_data_provider.ConsumeBool()) { + switch (fuzzed_data_provider.ConsumeIntegralInRange<int>(0, 4)) { + case 0: { + std::array<uint8_t, 4096> arr{}; + try { + opt_buffered_file->read((char*)arr.data(), fuzzed_data_provider.ConsumeIntegralInRange<size_t>(0, 4096)); + } catch (const std::ios_base::failure&) { + } + break; + } + case 1: { + opt_buffered_file->SetLimit(fuzzed_data_provider.ConsumeIntegralInRange<uint64_t>(0, 4096)); + break; + } + case 2: { + if (!opt_buffered_file->SetPos(fuzzed_data_provider.ConsumeIntegralInRange<uint64_t>(0, 4096))) { + setpos_fail = true; + } + break; + } + case 3: { + if (setpos_fail) { + // Calling FindByte(...) after a failed SetPos(...) call may result in an infinite loop. + break; + } + try { + opt_buffered_file->FindByte(fuzzed_data_provider.ConsumeIntegral<char>()); + } catch (const std::ios_base::failure&) { + } + break; + } + case 4: { + ReadFromStream(fuzzed_data_provider, *opt_buffered_file); + break; + } + } + } + opt_buffered_file->GetPos(); + opt_buffered_file->GetType(); + opt_buffered_file->GetVersion(); + } +} diff --git a/src/test/fuzz/coins_view.cpp b/src/test/fuzz/coins_view.cpp index 52dd62a145..c186bef7ae 100644 --- a/src/test/fuzz/coins_view.cpp +++ b/src/test/fuzz/coins_view.cpp @@ -278,7 +278,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) CCoinsStats stats; bool expected_code_path = false; try { - (void)GetUTXOStats(&coins_view_cache, stats); + (void)GetUTXOStats(&coins_view_cache, stats, CoinStatsHashType::HASH_SERIALIZED); } catch (const std::logic_error&) { expected_code_path = true; } diff --git a/src/test/fuzz/crypto.cpp b/src/test/fuzz/crypto.cpp new file mode 100644 index 0000000000..3edcf96495 --- /dev/null +++ b/src/test/fuzz/crypto.cpp @@ -0,0 +1,123 @@ +// Copyright (c) 2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include <crypto/hmac_sha256.h> +#include <crypto/hmac_sha512.h> +#include <crypto/ripemd160.h> +#include <crypto/sha1.h> +#include <crypto/sha256.h> +#include <crypto/sha512.h> +#include <hash.h> +#include <test/fuzz/FuzzedDataProvider.h> +#include <test/fuzz/fuzz.h> +#include <test/fuzz/util.h> + +#include <cstdint> +#include <vector> + +void test_one_input(const std::vector<uint8_t>& buffer) +{ + FuzzedDataProvider fuzzed_data_provider{buffer.data(), buffer.size()}; + std::vector<uint8_t> data = ConsumeRandomLengthByteVector(fuzzed_data_provider); + if (data.empty()) { + data.resize(fuzzed_data_provider.ConsumeIntegralInRange<size_t>(1, 4096), fuzzed_data_provider.ConsumeIntegral<uint8_t>()); + } + + CHash160 hash160; + CHash256 hash256; + CHMAC_SHA256 hmac_sha256{data.data(), data.size()}; + CHMAC_SHA512 hmac_sha512{data.data(), data.size()}; + CRIPEMD160 ripemd160; + CSHA1 sha1; + CSHA256 sha256; + CSHA512 sha512; + CSipHasher sip_hasher{fuzzed_data_provider.ConsumeIntegral<uint64_t>(), fuzzed_data_provider.ConsumeIntegral<uint64_t>()}; + + while (fuzzed_data_provider.ConsumeBool()) { + switch (fuzzed_data_provider.ConsumeIntegralInRange<int>(0, 2)) { + case 0: { + if (fuzzed_data_provider.ConsumeBool()) { + data = ConsumeRandomLengthByteVector(fuzzed_data_provider); + if (data.empty()) { + data.resize(fuzzed_data_provider.ConsumeIntegralInRange<size_t>(1, 4096), fuzzed_data_provider.ConsumeIntegral<uint8_t>()); + } + } + + (void)hash160.Write(data); + (void)hash256.Write(data); + (void)hmac_sha256.Write(data.data(), data.size()); + (void)hmac_sha512.Write(data.data(), data.size()); + (void)ripemd160.Write(data.data(), data.size()); + (void)sha1.Write(data.data(), data.size()); + (void)sha256.Write(data.data(), data.size()); + (void)sha512.Write(data.data(), data.size()); + (void)sip_hasher.Write(data.data(), data.size()); + + (void)Hash(data); + (void)Hash160(data); + (void)sha512.Size(); + break; + } + case 1: { + (void)hash160.Reset(); + (void)hash256.Reset(); + (void)ripemd160.Reset(); + (void)sha1.Reset(); + (void)sha256.Reset(); + (void)sha512.Reset(); + break; + } + case 2: { + switch (fuzzed_data_provider.ConsumeIntegralInRange<int>(0, 8)) { + case 0: { + data.resize(CHash160::OUTPUT_SIZE); + hash160.Finalize(data); + break; + } + case 1: { + data.resize(CHash256::OUTPUT_SIZE); + hash256.Finalize(data); + break; + } + case 2: { + data.resize(CHMAC_SHA256::OUTPUT_SIZE); + hmac_sha256.Finalize(data.data()); + break; + } + case 3: { + data.resize(CHMAC_SHA512::OUTPUT_SIZE); + hmac_sha512.Finalize(data.data()); + break; + } + case 4: { + data.resize(CRIPEMD160::OUTPUT_SIZE); + ripemd160.Finalize(data.data()); + break; + } + case 5: { + data.resize(CSHA1::OUTPUT_SIZE); + sha1.Finalize(data.data()); + break; + } + case 6: { + data.resize(CSHA256::OUTPUT_SIZE); + sha256.Finalize(data.data()); + break; + } + case 7: { + data.resize(CSHA512::OUTPUT_SIZE); + sha512.Finalize(data.data()); + break; + } + case 8: { + data.resize(1); + data[0] = sip_hasher.Finalize() % 256; + break; + } + } + break; + } + } + } +} diff --git a/src/test/fuzz/crypto_aes256.cpp b/src/test/fuzz/crypto_aes256.cpp new file mode 100644 index 0000000000..ae14073c96 --- /dev/null +++ b/src/test/fuzz/crypto_aes256.cpp @@ -0,0 +1,30 @@ +// Copyright (c) 2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include <crypto/aes.h> +#include <test/fuzz/FuzzedDataProvider.h> +#include <test/fuzz/fuzz.h> +#include <test/fuzz/util.h> + +#include <cassert> +#include <cstdint> +#include <vector> + +void test_one_input(const std::vector<uint8_t>& buffer) +{ + FuzzedDataProvider fuzzed_data_provider{buffer.data(), buffer.size()}; + const std::vector<uint8_t> key = ConsumeFixedLengthByteVector(fuzzed_data_provider, AES256_KEYSIZE); + + AES256Encrypt encrypt{key.data()}; + AES256Decrypt decrypt{key.data()}; + + while (fuzzed_data_provider.ConsumeBool()) { + const std::vector<uint8_t> plaintext = ConsumeFixedLengthByteVector(fuzzed_data_provider, AES_BLOCKSIZE); + std::vector<uint8_t> ciphertext(AES_BLOCKSIZE); + encrypt.Encrypt(ciphertext.data(), plaintext.data()); + std::vector<uint8_t> decrypted_plaintext(AES_BLOCKSIZE); + decrypt.Decrypt(decrypted_plaintext.data(), ciphertext.data()); + assert(decrypted_plaintext == plaintext); + } +} diff --git a/src/test/fuzz/crypto_aes256cbc.cpp b/src/test/fuzz/crypto_aes256cbc.cpp new file mode 100644 index 0000000000..52983c7e79 --- /dev/null +++ b/src/test/fuzz/crypto_aes256cbc.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include <crypto/aes.h> +#include <test/fuzz/FuzzedDataProvider.h> +#include <test/fuzz/fuzz.h> +#include <test/fuzz/util.h> + +#include <cassert> +#include <cstdint> +#include <vector> + +void test_one_input(const std::vector<uint8_t>& buffer) +{ + FuzzedDataProvider fuzzed_data_provider{buffer.data(), buffer.size()}; + const std::vector<uint8_t> key = ConsumeFixedLengthByteVector(fuzzed_data_provider, AES256_KEYSIZE); + const std::vector<uint8_t> iv = ConsumeFixedLengthByteVector(fuzzed_data_provider, AES_BLOCKSIZE); + const bool pad = fuzzed_data_provider.ConsumeBool(); + + AES256CBCEncrypt encrypt{key.data(), iv.data(), pad}; + AES256CBCDecrypt decrypt{key.data(), iv.data(), pad}; + + while (fuzzed_data_provider.ConsumeBool()) { + const std::vector<uint8_t> plaintext = ConsumeRandomLengthByteVector(fuzzed_data_provider); + std::vector<uint8_t> ciphertext(plaintext.size() + AES_BLOCKSIZE); + const int encrypt_ret = encrypt.Encrypt(plaintext.data(), plaintext.size(), ciphertext.data()); + ciphertext.resize(encrypt_ret); + std::vector<uint8_t> decrypted_plaintext(ciphertext.size()); + const int decrypt_ret = decrypt.Decrypt(ciphertext.data(), ciphertext.size(), decrypted_plaintext.data()); + decrypted_plaintext.resize(decrypt_ret); + assert(decrypted_plaintext == plaintext || (!pad && plaintext.size() % AES_BLOCKSIZE != 0 && encrypt_ret == 0 && decrypt_ret == 0)); + } +} diff --git a/src/test/fuzz/crypto_chacha20.cpp b/src/test/fuzz/crypto_chacha20.cpp new file mode 100644 index 0000000000..b7438d312d --- /dev/null +++ b/src/test/fuzz/crypto_chacha20.cpp @@ -0,0 +1,50 @@ +// Copyright (c) 2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include <crypto/chacha20.h> +#include <test/fuzz/FuzzedDataProvider.h> +#include <test/fuzz/fuzz.h> +#include <test/fuzz/util.h> + +#include <cstdint> +#include <vector> + +void test_one_input(const std::vector<uint8_t>& buffer) +{ + FuzzedDataProvider fuzzed_data_provider{buffer.data(), buffer.size()}; + + ChaCha20 chacha20; + if (fuzzed_data_provider.ConsumeBool()) { + const std::vector<unsigned char> key = ConsumeFixedLengthByteVector(fuzzed_data_provider, fuzzed_data_provider.ConsumeIntegralInRange<size_t>(16, 32)); + chacha20 = ChaCha20{key.data(), key.size()}; + } + while (fuzzed_data_provider.ConsumeBool()) { + switch (fuzzed_data_provider.ConsumeIntegralInRange(0, 4)) { + case 0: { + const std::vector<unsigned char> key = ConsumeFixedLengthByteVector(fuzzed_data_provider, fuzzed_data_provider.ConsumeIntegralInRange<size_t>(16, 32)); + chacha20.SetKey(key.data(), key.size()); + break; + } + case 1: { + chacha20.SetIV(fuzzed_data_provider.ConsumeIntegral<uint64_t>()); + break; + } + case 2: { + chacha20.Seek(fuzzed_data_provider.ConsumeIntegral<uint64_t>()); + break; + } + case 3: { + std::vector<uint8_t> output(fuzzed_data_provider.ConsumeIntegralInRange<size_t>(0, 4096)); + chacha20.Keystream(output.data(), output.size()); + break; + } + case 4: { + std::vector<uint8_t> output(fuzzed_data_provider.ConsumeIntegralInRange<size_t>(0, 4096)); + const std::vector<uint8_t> input = ConsumeFixedLengthByteVector(fuzzed_data_provider, output.size()); + chacha20.Crypt(input.data(), output.data(), input.size()); + break; + } + } + } +} diff --git a/src/test/fuzz/crypto_chacha20_poly1305_aead.cpp b/src/test/fuzz/crypto_chacha20_poly1305_aead.cpp new file mode 100644 index 0000000000..48e4263f27 --- /dev/null +++ b/src/test/fuzz/crypto_chacha20_poly1305_aead.cpp @@ -0,0 +1,72 @@ +// Copyright (c) 2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include <crypto/chacha_poly_aead.h> +#include <crypto/poly1305.h> +#include <test/fuzz/FuzzedDataProvider.h> +#include <test/fuzz/fuzz.h> +#include <test/fuzz/util.h> + +#include <cassert> +#include <cstdint> +#include <limits> +#include <vector> + +void test_one_input(const std::vector<uint8_t>& buffer) +{ + FuzzedDataProvider fuzzed_data_provider{buffer.data(), buffer.size()}; + + const std::vector<uint8_t> k1 = ConsumeFixedLengthByteVector(fuzzed_data_provider, CHACHA20_POLY1305_AEAD_KEY_LEN); + const std::vector<uint8_t> k2 = ConsumeFixedLengthByteVector(fuzzed_data_provider, CHACHA20_POLY1305_AEAD_KEY_LEN); + + ChaCha20Poly1305AEAD aead(k1.data(), k1.size(), k2.data(), k2.size()); + uint64_t seqnr_payload = 0; + uint64_t seqnr_aad = 0; + int aad_pos = 0; + size_t buffer_size = fuzzed_data_provider.ConsumeIntegralInRange<size_t>(0, 4096); + std::vector<uint8_t> in(buffer_size + CHACHA20_POLY1305_AEAD_AAD_LEN + POLY1305_TAGLEN, 0); + std::vector<uint8_t> out(buffer_size + CHACHA20_POLY1305_AEAD_AAD_LEN + POLY1305_TAGLEN, 0); + bool is_encrypt = fuzzed_data_provider.ConsumeBool(); + while (fuzzed_data_provider.ConsumeBool()) { + switch (fuzzed_data_provider.ConsumeIntegralInRange<int>(0, 6)) { + case 0: { + buffer_size = fuzzed_data_provider.ConsumeIntegralInRange<size_t>(64, 4096); + in = std::vector<uint8_t>(buffer_size + CHACHA20_POLY1305_AEAD_AAD_LEN + POLY1305_TAGLEN, 0); + out = std::vector<uint8_t>(buffer_size + CHACHA20_POLY1305_AEAD_AAD_LEN + POLY1305_TAGLEN, 0); + break; + } + case 1: { + (void)aead.Crypt(seqnr_payload, seqnr_aad, aad_pos, out.data(), out.size(), in.data(), buffer_size, is_encrypt); + break; + } + case 2: { + uint32_t len = 0; + const bool ok = aead.GetLength(&len, seqnr_aad, aad_pos, in.data()); + assert(ok); + break; + } + case 3: { + seqnr_payload += 1; + aad_pos += CHACHA20_POLY1305_AEAD_AAD_LEN; + if (aad_pos + CHACHA20_POLY1305_AEAD_AAD_LEN > CHACHA20_ROUND_OUTPUT) { + aad_pos = 0; + seqnr_aad += 1; + } + break; + } + case 4: { + seqnr_payload = fuzzed_data_provider.ConsumeIntegral<int>(); + break; + } + case 5: { + seqnr_aad = fuzzed_data_provider.ConsumeIntegral<int>(); + break; + } + case 6: { + is_encrypt = fuzzed_data_provider.ConsumeBool(); + break; + } + } + } +} diff --git a/src/test/fuzz/crypto_common.cpp b/src/test/fuzz/crypto_common.cpp new file mode 100644 index 0000000000..7ccb125216 --- /dev/null +++ b/src/test/fuzz/crypto_common.cpp @@ -0,0 +1,70 @@ +// Copyright (c) 2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include <crypto/common.h> +#include <test/fuzz/FuzzedDataProvider.h> +#include <test/fuzz/fuzz.h> +#include <test/fuzz/util.h> + +#include <array> +#include <cassert> +#include <cstdint> +#include <cstring> +#include <vector> + +void test_one_input(const std::vector<uint8_t>& buffer) +{ + FuzzedDataProvider fuzzed_data_provider{buffer.data(), buffer.size()}; + const uint16_t random_u16 = fuzzed_data_provider.ConsumeIntegral<uint16_t>(); + const uint32_t random_u32 = fuzzed_data_provider.ConsumeIntegral<uint32_t>(); + const uint64_t random_u64 = fuzzed_data_provider.ConsumeIntegral<uint64_t>(); + const std::vector<uint8_t> random_bytes_2 = ConsumeFixedLengthByteVector(fuzzed_data_provider, 2); + const std::vector<uint8_t> random_bytes_4 = ConsumeFixedLengthByteVector(fuzzed_data_provider, 4); + const std::vector<uint8_t> random_bytes_8 = ConsumeFixedLengthByteVector(fuzzed_data_provider, 8); + + std::array<uint8_t, 2> writele16_arr; + WriteLE16(writele16_arr.data(), random_u16); + assert(ReadLE16(writele16_arr.data()) == random_u16); + + std::array<uint8_t, 4> writele32_arr; + WriteLE32(writele32_arr.data(), random_u32); + assert(ReadLE32(writele32_arr.data()) == random_u32); + + std::array<uint8_t, 8> writele64_arr; + WriteLE64(writele64_arr.data(), random_u64); + assert(ReadLE64(writele64_arr.data()) == random_u64); + + std::array<uint8_t, 4> writebe32_arr; + WriteBE32(writebe32_arr.data(), random_u32); + assert(ReadBE32(writebe32_arr.data()) == random_u32); + + std::array<uint8_t, 8> writebe64_arr; + WriteBE64(writebe64_arr.data(), random_u64); + assert(ReadBE64(writebe64_arr.data()) == random_u64); + + const uint16_t readle16_result = ReadLE16(random_bytes_2.data()); + std::array<uint8_t, 2> readle16_arr; + WriteLE16(readle16_arr.data(), readle16_result); + assert(std::memcmp(random_bytes_2.data(), readle16_arr.data(), 2) == 0); + + const uint32_t readle32_result = ReadLE32(random_bytes_4.data()); + std::array<uint8_t, 4> readle32_arr; + WriteLE32(readle32_arr.data(), readle32_result); + assert(std::memcmp(random_bytes_4.data(), readle32_arr.data(), 4) == 0); + + const uint64_t readle64_result = ReadLE64(random_bytes_8.data()); + std::array<uint8_t, 8> readle64_arr; + WriteLE64(readle64_arr.data(), readle64_result); + assert(std::memcmp(random_bytes_8.data(), readle64_arr.data(), 8) == 0); + + const uint32_t readbe32_result = ReadBE32(random_bytes_4.data()); + std::array<uint8_t, 4> readbe32_arr; + WriteBE32(readbe32_arr.data(), readbe32_result); + assert(std::memcmp(random_bytes_4.data(), readbe32_arr.data(), 4) == 0); + + const uint64_t readbe64_result = ReadBE64(random_bytes_8.data()); + std::array<uint8_t, 8> readbe64_arr; + WriteBE64(readbe64_arr.data(), readbe64_result); + assert(std::memcmp(random_bytes_8.data(), readbe64_arr.data(), 8) == 0); +} diff --git a/src/test/fuzz/crypto_hkdf_hmac_sha256_l32.cpp b/src/test/fuzz/crypto_hkdf_hmac_sha256_l32.cpp new file mode 100644 index 0000000000..e0a4e90c10 --- /dev/null +++ b/src/test/fuzz/crypto_hkdf_hmac_sha256_l32.cpp @@ -0,0 +1,25 @@ +// Copyright (c) 2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include <crypto/hkdf_sha256_32.h> +#include <test/fuzz/FuzzedDataProvider.h> +#include <test/fuzz/fuzz.h> +#include <test/fuzz/util.h> + +#include <cstdint> +#include <string> +#include <vector> + +void test_one_input(const std::vector<uint8_t>& buffer) +{ + FuzzedDataProvider fuzzed_data_provider{buffer.data(), buffer.size()}; + + const std::vector<uint8_t> initial_key_material = ConsumeRandomLengthByteVector(fuzzed_data_provider); + + CHKDF_HMAC_SHA256_L32 hkdf_hmac_sha256_l32(initial_key_material.data(), initial_key_material.size(), fuzzed_data_provider.ConsumeRandomLengthString(1024)); + while (fuzzed_data_provider.ConsumeBool()) { + std::vector<uint8_t> out(32); + hkdf_hmac_sha256_l32.Expand32(fuzzed_data_provider.ConsumeRandomLengthString(128), out.data()); + } +} diff --git a/src/test/fuzz/crypto_poly1305.cpp b/src/test/fuzz/crypto_poly1305.cpp new file mode 100644 index 0000000000..5681e6a693 --- /dev/null +++ b/src/test/fuzz/crypto_poly1305.cpp @@ -0,0 +1,22 @@ +// Copyright (c) 2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include <crypto/poly1305.h> +#include <test/fuzz/FuzzedDataProvider.h> +#include <test/fuzz/fuzz.h> +#include <test/fuzz/util.h> + +#include <cstdint> +#include <vector> + +void test_one_input(const std::vector<uint8_t>& buffer) +{ + FuzzedDataProvider fuzzed_data_provider{buffer.data(), buffer.size()}; + + const std::vector<uint8_t> key = ConsumeFixedLengthByteVector(fuzzed_data_provider, POLY1305_KEYLEN); + const std::vector<uint8_t> in = ConsumeRandomLengthByteVector(fuzzed_data_provider); + + std::vector<uint8_t> tag_out(POLY1305_TAGLEN); + poly1305_auth(tag_out.data(), in.data(), in.size(), key.data()); +} diff --git a/src/test/fuzz/decode_tx.cpp b/src/test/fuzz/decode_tx.cpp index 09c4ff05df..0d89d4228a 100644 --- a/src/test/fuzz/decode_tx.cpp +++ b/src/test/fuzz/decode_tx.cpp @@ -14,7 +14,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) { - const std::string tx_hex = HexStr(std::string{buffer.begin(), buffer.end()}); + const std::string tx_hex = HexStr(buffer); CMutableTransaction mtx; const bool result_none = DecodeHexTx(mtx, tx_hex, false, false); const bool result_try_witness = DecodeHexTx(mtx, tx_hex, false, true); diff --git a/src/test/fuzz/fuzz.cpp b/src/test/fuzz/fuzz.cpp index 82e1d55c0b..1e1807d734 100644 --- a/src/test/fuzz/fuzz.cpp +++ b/src/test/fuzz/fuzz.cpp @@ -12,7 +12,16 @@ const std::function<void(const std::string&)> G_TEST_LOG_FUN{}; -#if defined(__AFL_COMPILER) +// Decide if main(...) should be provided: +// * AFL needs main(...) regardless of platform. +// * macOS handles __attribute__((weak)) main(...) poorly when linking +// against libFuzzer. See https://github.com/bitcoin/bitcoin/pull/18008 +// for details. +#if defined(__AFL_COMPILER) || !defined(MAC_OSX) +#define PROVIDE_MAIN_FUNCTION +#endif + +#if defined(PROVIDE_MAIN_FUNCTION) static bool read_stdin(std::vector<uint8_t>& data) { uint8_t buffer[1024]; @@ -44,9 +53,8 @@ extern "C" int LLVMFuzzerInitialize(int* argc, char*** argv) return 0; } -// Generally, the fuzzer will provide main(), except for AFL -#if defined(__AFL_COMPILER) -int main(int argc, char** argv) +#if defined(PROVIDE_MAIN_FUNCTION) +__attribute__((weak)) int main(int argc, char** argv) { initialize(); #ifdef __AFL_INIT diff --git a/src/test/fuzz/http_request.cpp b/src/test/fuzz/http_request.cpp index ebf89749e9..36d44e361f 100644 --- a/src/test/fuzz/http_request.cpp +++ b/src/test/fuzz/http_request.cpp @@ -7,6 +7,7 @@ #include <test/fuzz/FuzzedDataProvider.h> #include <test/fuzz/fuzz.h> #include <test/fuzz/util.h> +#include <util/strencodings.h> #include <event2/buffer.h> #include <event2/event.h> @@ -48,7 +49,14 @@ void test_one_input(const std::vector<uint8_t>& buffer) assert(evbuf != nullptr); const std::vector<uint8_t> http_buffer = ConsumeRandomLengthByteVector(fuzzed_data_provider, 4096); evbuffer_add(evbuf, http_buffer.data(), http_buffer.size()); - if (evhttp_parse_firstline_(evreq, evbuf) != 1 || evhttp_parse_headers_(evreq, evbuf) != 1) { + // Avoid constructing requests that will be interpreted by libevent as PROXY requests to avoid triggering + // a nullptr dereference. The dereference (req->evcon->http_server) takes place in evhttp_parse_request_line + // and is a consequence of our hacky but necessary use of the internal function evhttp_parse_firstline_ in + // this fuzzing harness. The workaround is not aesthetically pleasing, but it successfully avoids the troublesome + // code path. " http:// HTTP/1.1\n" was a crashing input prior to this workaround. + const std::string http_buffer_str = ToLower({http_buffer.begin(), http_buffer.end()}); + if (http_buffer_str.find(" http://") != std::string::npos || http_buffer_str.find(" https://") != std::string::npos || + evhttp_parse_firstline_(evreq, evbuf) != 1 || evhttp_parse_headers_(evreq, evbuf) != 1) { evbuffer_free(evbuf); evhttp_request_free(evreq); return; diff --git a/src/test/fuzz/key.cpp b/src/test/fuzz/key.cpp index 1919a5f881..955b954700 100644 --- a/src/test/fuzz/key.cpp +++ b/src/test/fuzz/key.cpp @@ -85,7 +85,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) assert(negated_key == key); } - const uint256 random_uint256 = Hash(buffer.begin(), buffer.end()); + const uint256 random_uint256 = Hash(buffer); { CKey child_key; @@ -108,7 +108,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) assert(pubkey.IsCompressed()); assert(pubkey.IsValid()); assert(pubkey.IsFullyValid()); - assert(HexToPubKey(HexStr(pubkey.begin(), pubkey.end())) == pubkey); + assert(HexToPubKey(HexStr(pubkey)) == pubkey); assert(GetAllDestinationsForKey(pubkey).size() == 3); } @@ -157,25 +157,25 @@ void test_one_input(const std::vector<uint8_t>& buffer) assert(ok_add_key_pubkey); assert(fillable_signing_provider_pub.HaveKey(pubkey.GetID())); - txnouttype which_type_tx_pubkey; + TxoutType which_type_tx_pubkey; const bool is_standard_tx_pubkey = IsStandard(tx_pubkey_script, which_type_tx_pubkey); assert(is_standard_tx_pubkey); - assert(which_type_tx_pubkey == txnouttype::TX_PUBKEY); + assert(which_type_tx_pubkey == TxoutType::PUBKEY); - txnouttype which_type_tx_multisig; + TxoutType which_type_tx_multisig; const bool is_standard_tx_multisig = IsStandard(tx_multisig_script, which_type_tx_multisig); assert(is_standard_tx_multisig); - assert(which_type_tx_multisig == txnouttype::TX_MULTISIG); + assert(which_type_tx_multisig == TxoutType::MULTISIG); std::vector<std::vector<unsigned char>> v_solutions_ret_tx_pubkey; - const txnouttype outtype_tx_pubkey = Solver(tx_pubkey_script, v_solutions_ret_tx_pubkey); - assert(outtype_tx_pubkey == txnouttype::TX_PUBKEY); + const TxoutType outtype_tx_pubkey = Solver(tx_pubkey_script, v_solutions_ret_tx_pubkey); + assert(outtype_tx_pubkey == TxoutType::PUBKEY); assert(v_solutions_ret_tx_pubkey.size() == 1); assert(v_solutions_ret_tx_pubkey[0].size() == 33); std::vector<std::vector<unsigned char>> v_solutions_ret_tx_multisig; - const txnouttype outtype_tx_multisig = Solver(tx_multisig_script, v_solutions_ret_tx_multisig); - assert(outtype_tx_multisig == txnouttype::TX_MULTISIG); + const TxoutType outtype_tx_multisig = Solver(tx_multisig_script, v_solutions_ret_tx_multisig); + assert(outtype_tx_multisig == TxoutType::MULTISIG); assert(v_solutions_ret_tx_multisig.size() == 3); assert(v_solutions_ret_tx_multisig[0].size() == 1); assert(v_solutions_ret_tx_multisig[1].size() == 33); diff --git a/src/test/fuzz/kitchen_sink.cpp b/src/test/fuzz/kitchen_sink.cpp index af6dc71322..82cbc00a3a 100644 --- a/src/test/fuzz/kitchen_sink.cpp +++ b/src/test/fuzz/kitchen_sink.cpp @@ -7,6 +7,7 @@ #include <test/fuzz/fuzz.h> #include <test/fuzz/util.h> #include <util/error.h> +#include <util/translation.h> #include <cstdint> #include <vector> diff --git a/src/test/fuzz/load_external_block_file.cpp b/src/test/fuzz/load_external_block_file.cpp new file mode 100644 index 0000000000..d9de9d9866 --- /dev/null +++ b/src/test/fuzz/load_external_block_file.cpp @@ -0,0 +1,31 @@ +// Copyright (c) 2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include <chainparams.h> +#include <flatfile.h> +#include <test/fuzz/FuzzedDataProvider.h> +#include <test/fuzz/fuzz.h> +#include <test/fuzz/util.h> +#include <test/util/setup_common.h> +#include <validation.h> + +#include <cstdint> +#include <vector> + +void initialize() +{ + InitializeFuzzingContext(); +} + +void test_one_input(const std::vector<uint8_t>& buffer) +{ + FuzzedDataProvider fuzzed_data_provider{buffer.data(), buffer.size()}; + FuzzedFileProvider fuzzed_file_provider = ConsumeFile(fuzzed_data_provider); + FILE* fuzzed_block_file = fuzzed_file_provider.open(); + if (fuzzed_block_file == nullptr) { + return; + } + FlatFilePos flat_file_pos; + LoadExternalBlockFile(Params(), fuzzed_block_file, fuzzed_data_provider.ConsumeBool() ? &flat_file_pos : nullptr); +} diff --git a/src/test/fuzz/net_permissions.cpp b/src/test/fuzz/net_permissions.cpp index c071283467..8a674ac1e9 100644 --- a/src/test/fuzz/net_permissions.cpp +++ b/src/test/fuzz/net_permissions.cpp @@ -6,6 +6,7 @@ #include <test/fuzz/FuzzedDataProvider.h> #include <test/fuzz/fuzz.h> #include <test/fuzz/util.h> +#include <util/translation.h> #include <cassert> #include <cstdint> @@ -23,13 +24,14 @@ void test_one_input(const std::vector<uint8_t>& buffer) NetPermissionFlags::PF_FORCERELAY, NetPermissionFlags::PF_NOBAN, NetPermissionFlags::PF_MEMPOOL, + NetPermissionFlags::PF_ADDR, NetPermissionFlags::PF_ISIMPLICIT, NetPermissionFlags::PF_ALL, }) : static_cast<NetPermissionFlags>(fuzzed_data_provider.ConsumeIntegral<uint32_t>()); NetWhitebindPermissions net_whitebind_permissions; - std::string error_net_whitebind_permissions; + bilingual_str error_net_whitebind_permissions; if (NetWhitebindPermissions::TryParse(s, net_whitebind_permissions, error_net_whitebind_permissions)) { (void)NetPermissions::ToStrings(net_whitebind_permissions.m_flags); (void)NetPermissions::AddFlag(net_whitebind_permissions.m_flags, net_permission_flags); @@ -39,7 +41,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) } NetWhitelistPermissions net_whitelist_permissions; - std::string error_net_whitelist_permissions; + bilingual_str error_net_whitelist_permissions; if (NetWhitelistPermissions::TryParse(s, net_whitelist_permissions, error_net_whitelist_permissions)) { (void)NetPermissions::ToStrings(net_whitelist_permissions.m_flags); (void)NetPermissions::AddFlag(net_whitelist_permissions.m_flags, net_permission_flags); diff --git a/src/test/fuzz/netaddress.cpp b/src/test/fuzz/netaddress.cpp index d8d53566c7..2901c704f6 100644 --- a/src/test/fuzz/netaddress.cpp +++ b/src/test/fuzz/netaddress.cpp @@ -5,41 +5,13 @@ #include <netaddress.h> #include <test/fuzz/FuzzedDataProvider.h> #include <test/fuzz/fuzz.h> +#include <test/fuzz/util.h> #include <cassert> #include <cstdint> #include <netinet/in.h> #include <vector> -namespace { -CNetAddr ConsumeNetAddr(FuzzedDataProvider& fuzzed_data_provider) noexcept -{ - const Network network = fuzzed_data_provider.PickValueInArray({Network::NET_IPV4, Network::NET_IPV6, Network::NET_INTERNAL, Network::NET_ONION}); - if (network == Network::NET_IPV4) { - const in_addr v4_addr = { - .s_addr = fuzzed_data_provider.ConsumeIntegral<uint32_t>()}; - return CNetAddr{v4_addr}; - } else if (network == Network::NET_IPV6) { - if (fuzzed_data_provider.remaining_bytes() < 16) { - return CNetAddr{}; - } - in6_addr v6_addr = {}; - memcpy(v6_addr.s6_addr, fuzzed_data_provider.ConsumeBytes<uint8_t>(16).data(), 16); - return CNetAddr{v6_addr, fuzzed_data_provider.ConsumeIntegral<uint32_t>()}; - } else if (network == Network::NET_INTERNAL) { - CNetAddr net_addr; - net_addr.SetInternal(fuzzed_data_provider.ConsumeBytesAsString(32)); - return net_addr; - } else if (network == Network::NET_ONION) { - CNetAddr net_addr; - net_addr.SetSpecial(fuzzed_data_provider.ConsumeBytesAsString(32)); - return net_addr; - } else { - assert(false); - } -} -}; // namespace - void test_one_input(const std::vector<uint8_t>& buffer) { FuzzedDataProvider fuzzed_data_provider(buffer.data(), buffer.size()); diff --git a/src/test/fuzz/p2p_transport_deserializer.cpp b/src/test/fuzz/p2p_transport_deserializer.cpp index 57393fed45..6fba2bfaba 100644 --- a/src/test/fuzz/p2p_transport_deserializer.cpp +++ b/src/test/fuzz/p2p_transport_deserializer.cpp @@ -30,7 +30,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) pch += handled; n_bytes -= handled; if (deserializer.Complete()) { - const int64_t m_time = std::numeric_limits<int64_t>::max(); + const std::chrono::microseconds m_time{std::numeric_limits<int64_t>::max()}; const CNetMessage msg = deserializer.GetMessage(Params().MessageStart(), m_time); assert(msg.m_command.size() <= CMessageHeader::COMMAND_SIZE); assert(msg.m_raw_message_size <= buffer.size()); diff --git a/src/test/fuzz/policy_estimator.cpp b/src/test/fuzz/policy_estimator.cpp index 1cbf9b347f..6c94a47f3c 100644 --- a/src/test/fuzz/policy_estimator.cpp +++ b/src/test/fuzz/policy_estimator.cpp @@ -14,6 +14,11 @@ #include <string> #include <vector> +void initialize() +{ + InitializeFuzzingContext(); +} + void test_one_input(const std::vector<uint8_t>& buffer) { FuzzedDataProvider fuzzed_data_provider(buffer.data(), buffer.size()); @@ -66,4 +71,10 @@ void test_one_input(const std::vector<uint8_t>& buffer) (void)block_policy_estimator.estimateSmartFee(fuzzed_data_provider.ConsumeIntegral<int>(), fuzzed_data_provider.ConsumeBool() ? &fee_calculation : nullptr, fuzzed_data_provider.ConsumeBool()); (void)block_policy_estimator.HighestTargetTracked(fuzzed_data_provider.PickValueInArray({FeeEstimateHorizon::SHORT_HALFLIFE, FeeEstimateHorizon::MED_HALFLIFE, FeeEstimateHorizon::LONG_HALFLIFE})); } + { + FuzzedAutoFileProvider fuzzed_auto_file_provider = ConsumeAutoFile(fuzzed_data_provider); + CAutoFile fuzzed_auto_file = fuzzed_auto_file_provider.open(); + block_policy_estimator.Write(fuzzed_auto_file); + block_policy_estimator.Read(fuzzed_auto_file); + } } diff --git a/src/test/fuzz/policy_estimator_io.cpp b/src/test/fuzz/policy_estimator_io.cpp new file mode 100644 index 0000000000..0edcf201c7 --- /dev/null +++ b/src/test/fuzz/policy_estimator_io.cpp @@ -0,0 +1,28 @@ +// Copyright (c) 2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include <policy/fees.h> +#include <test/fuzz/FuzzedDataProvider.h> +#include <test/fuzz/fuzz.h> +#include <test/fuzz/util.h> + +#include <cstdint> +#include <vector> + +void initialize() +{ + InitializeFuzzingContext(); +} + +void test_one_input(const std::vector<uint8_t>& buffer) +{ + FuzzedDataProvider fuzzed_data_provider(buffer.data(), buffer.size()); + FuzzedAutoFileProvider fuzzed_auto_file_provider = ConsumeAutoFile(fuzzed_data_provider); + CAutoFile fuzzed_auto_file = fuzzed_auto_file_provider.open(); + // Re-using block_policy_estimator across runs to avoid costly creation of CBlockPolicyEstimator object. + static CBlockPolicyEstimator block_policy_estimator; + if (block_policy_estimator.Read(fuzzed_auto_file)) { + block_policy_estimator.Write(fuzzed_auto_file); + } +} diff --git a/src/test/fuzz/process_message.cpp b/src/test/fuzz/process_message.cpp index 665a6224b4..677b87a47a 100644 --- a/src/test/fuzz/process_message.cpp +++ b/src/test/fuzz/process_message.cpp @@ -14,6 +14,7 @@ #include <test/fuzz/FuzzedDataProvider.h> #include <test/fuzz/fuzz.h> #include <test/util/mining.h> +#include <test/util/net.h> #include <test/util/setup_common.h> #include <util/memory.h> #include <validationinterface.h> @@ -29,7 +30,17 @@ #include <string> #include <vector> -bool ProcessMessage(CNode* pfrom, const std::string& msg_type, CDataStream& vRecv, int64_t nTimeReceived, const CChainParams& chainparams, ChainstateManager& chainman, CTxMemPool& mempool, CConnman* connman, BanMan* banman, const std::atomic<bool>& interruptMsgProc); +void ProcessMessage( + CNode& pfrom, + const std::string& msg_type, + CDataStream& vRecv, + const std::chrono::microseconds time_received, + const CChainParams& chainparams, + ChainstateManager& chainman, + CTxMemPool& mempool, + CConnman& connman, + BanMan* banman, + const std::atomic<bool>& interruptMsgProc); namespace { @@ -63,19 +74,26 @@ void initialize() void test_one_input(const std::vector<uint8_t>& buffer) { FuzzedDataProvider fuzzed_data_provider(buffer.data(), buffer.size()); + ConnmanTestMsg& connman = *(ConnmanTestMsg*)g_setup->m_node.connman.get(); const std::string random_message_type{fuzzed_data_provider.ConsumeBytesAsString(CMessageHeader::COMMAND_SIZE).c_str()}; if (!LIMIT_TO_MESSAGE_TYPE.empty() && random_message_type != LIMIT_TO_MESSAGE_TYPE) { return; } CDataStream random_bytes_data_stream{fuzzed_data_provider.ConsumeRemainingBytes<unsigned char>(), SER_NETWORK, PROTOCOL_VERSION}; - CNode p2p_node{0, ServiceFlags(NODE_NETWORK | NODE_WITNESS | NODE_BLOOM), 0, INVALID_SOCKET, CAddress{CService{in_addr{0x0100007f}, 7777}, NODE_NETWORK}, 0, 0, CAddress{}, std::string{}, false}; + CNode& p2p_node = *MakeUnique<CNode>(0, ServiceFlags(NODE_NETWORK | NODE_WITNESS | NODE_BLOOM), 0, INVALID_SOCKET, CAddress{CService{in_addr{0x0100007f}, 7777}, NODE_NETWORK}, 0, 0, CAddress{}, std::string{}, ConnectionType::OUTBOUND).release(); p2p_node.fSuccessfullyConnected = true; p2p_node.nVersion = PROTOCOL_VERSION; p2p_node.SetSendVersion(PROTOCOL_VERSION); + connman.AddTestNode(p2p_node); g_setup->m_node.peer_logic->InitializeNode(&p2p_node); try { - (void)ProcessMessage(&p2p_node, random_message_type, random_bytes_data_stream, GetTimeMillis(), Params(), *g_setup->m_node.chainman, *g_setup->m_node.mempool, g_setup->m_node.connman.get(), g_setup->m_node.banman.get(), std::atomic<bool>{false}); + ProcessMessage(p2p_node, random_message_type, random_bytes_data_stream, GetTime<std::chrono::microseconds>(), + Params(), *g_setup->m_node.chainman, *g_setup->m_node.mempool, + *g_setup->m_node.connman, g_setup->m_node.banman.get(), + std::atomic<bool>{false}); } catch (const std::ios_base::failure&) { } SyncWithValidationInterfaceQueue(); + LOCK2(::cs_main, g_cs_orphans); // See init.cpp for rationale for implicit locking order requirement + g_setup->m_node.connman->StopNodes(); } diff --git a/src/test/fuzz/process_messages.cpp b/src/test/fuzz/process_messages.cpp index bcbf65bdca..ef427442e9 100644 --- a/src/test/fuzz/process_messages.cpp +++ b/src/test/fuzz/process_messages.cpp @@ -44,9 +44,8 @@ void test_one_input(const std::vector<uint8_t>& buffer) const auto num_peers_to_add = fuzzed_data_provider.ConsumeIntegralInRange(1, 3); for (int i = 0; i < num_peers_to_add; ++i) { const ServiceFlags service_flags = ServiceFlags(fuzzed_data_provider.ConsumeIntegral<uint64_t>()); - const bool inbound{fuzzed_data_provider.ConsumeBool()}; - const bool block_relay_only{fuzzed_data_provider.ConsumeBool()}; - peers.push_back(MakeUnique<CNode>(i, service_flags, 0, INVALID_SOCKET, CAddress{CService{in_addr{0x0100007f}, 7777}, NODE_NETWORK}, 0, 0, CAddress{}, std::string{}, inbound, block_relay_only).release()); + const ConnectionType conn_type = fuzzed_data_provider.PickValueInArray({ConnectionType::INBOUND, ConnectionType::OUTBOUND, ConnectionType::MANUAL, ConnectionType::FEELER, ConnectionType::BLOCK_RELAY, ConnectionType::ADDR_FETCH}); + peers.push_back(MakeUnique<CNode>(i, service_flags, 0, INVALID_SOCKET, CAddress{CService{in_addr{0x0100007f}, 7777}, NODE_NETWORK}, 0, 0, CAddress{}, std::string{}, conn_type).release()); CNode& p2p_node = *peers.back(); p2p_node.fSuccessfullyConnected = true; @@ -62,7 +61,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) const std::string random_message_type{fuzzed_data_provider.ConsumeBytesAsString(CMessageHeader::COMMAND_SIZE).c_str()}; CSerializedNetMsg net_msg; - net_msg.command = random_message_type; + net_msg.m_type = random_message_type; net_msg.data = ConsumeRandomLengthByteVector(fuzzed_data_provider); CNode& random_node = *peers.at(fuzzed_data_provider.ConsumeIntegralInRange<int>(0, peers.size() - 1)); @@ -75,6 +74,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) } catch (const std::ios_base::failure&) { } } - connman.ClearTestNodes(); SyncWithValidationInterfaceQueue(); + LOCK2(::cs_main, g_cs_orphans); // See init.cpp for rationale for implicit locking order requirement + g_setup->m_node.connman->StopNodes(); } diff --git a/src/test/fuzz/psbt.cpp b/src/test/fuzz/psbt.cpp index 64328fb66e..908e2b16f2 100644 --- a/src/test/fuzz/psbt.cpp +++ b/src/test/fuzz/psbt.cpp @@ -39,7 +39,6 @@ void test_one_input(const std::vector<uint8_t>& buffer) } (void)psbt.IsNull(); - (void)psbt.IsSane(); Optional<CMutableTransaction> tx = psbt.tx; if (tx) { @@ -50,7 +49,6 @@ void test_one_input(const std::vector<uint8_t>& buffer) for (const PSBTInput& input : psbt.inputs) { (void)PSBTInputSigned(input); (void)input.IsNull(); - (void)input.IsSane(); } for (const PSBTOutput& output : psbt.outputs) { diff --git a/src/test/fuzz/script.cpp b/src/test/fuzz/script.cpp index 933cf9049d..85aac6ac7a 100644 --- a/src/test/fuzz/script.cpp +++ b/src/test/fuzz/script.cpp @@ -48,7 +48,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) if (CompressScript(script, compressed)) { const unsigned int size = compressed[0]; compressed.erase(compressed.begin()); - assert(size >= 0 && size <= 5); + assert(size <= 5); CScript decompressed_script; const bool ok = DecompressScript(decompressed_script, size, compressed); assert(ok); @@ -58,7 +58,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) CTxDestination address; (void)ExtractDestination(script, address); - txnouttype type_ret; + TxoutType type_ret; std::vector<CTxDestination> addresses; int required_ret; (void)ExtractDestinations(script, type_ret, addresses, required_ret); @@ -72,7 +72,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) (void)IsSolvable(signing_provider, script); - txnouttype which_type; + TxoutType which_type; (void)IsStandard(script, which_type); (void)RecursiveDynamicUsage(script); diff --git a/src/test/fuzz/scriptnum_ops.cpp b/src/test/fuzz/scriptnum_ops.cpp index f4e079fb89..68c1ae58ca 100644 --- a/src/test/fuzz/scriptnum_ops.cpp +++ b/src/test/fuzz/scriptnum_ops.cpp @@ -33,7 +33,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) case 0: { const int64_t i = fuzzed_data_provider.ConsumeIntegral<int64_t>(); assert((script_num == i) != (script_num != i)); - assert((script_num <= i) != script_num > i); + assert((script_num <= i) != (script_num > i)); assert((script_num >= i) != (script_num < i)); // Avoid signed integer overflow: // script/script.h:264:93: runtime error: signed integer overflow: -2261405121394637306 + -9223372036854775802 cannot be represented in type 'long' diff --git a/src/test/fuzz/signature_checker.cpp b/src/test/fuzz/signature_checker.cpp index 4a8c7a63af..3aaeb66649 100644 --- a/src/test/fuzz/signature_checker.cpp +++ b/src/test/fuzz/signature_checker.cpp @@ -28,17 +28,17 @@ public: { } - virtual bool CheckSig(const std::vector<unsigned char>& scriptSig, const std::vector<unsigned char>& vchPubKey, const CScript& scriptCode, SigVersion sigversion) const + bool CheckSig(const std::vector<unsigned char>& scriptSig, const std::vector<unsigned char>& vchPubKey, const CScript& scriptCode, SigVersion sigversion) const override { return m_fuzzed_data_provider.ConsumeBool(); } - virtual bool CheckLockTime(const CScriptNum& nLockTime) const + bool CheckLockTime(const CScriptNum& nLockTime) const override { return m_fuzzed_data_provider.ConsumeBool(); } - virtual bool CheckSequence(const CScriptNum& nSequence) const + bool CheckSequence(const CScriptNum& nSequence) const override { return m_fuzzed_data_provider.ConsumeBool(); } diff --git a/src/test/fuzz/span.cpp b/src/test/fuzz/span.cpp index 4aea530ef2..f6b6e8f6f0 100644 --- a/src/test/fuzz/span.cpp +++ b/src/test/fuzz/span.cpp @@ -18,7 +18,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) FuzzedDataProvider fuzzed_data_provider(buffer.data(), buffer.size()); std::string str = fuzzed_data_provider.ConsumeBytesAsString(32); - const Span<const char> span = MakeSpan(str); + const Span<const char> span{str}; (void)span.data(); (void)span.begin(); (void)span.end(); @@ -32,7 +32,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) } std::string another_str = fuzzed_data_provider.ConsumeBytesAsString(32); - const Span<const char> another_span = MakeSpan(another_str); + const Span<const char> another_span{another_str}; assert((span <= another_span) != (span > another_span)); assert((span == another_span) != (span != another_span)); assert((span >= another_span) != (span < another_span)); diff --git a/src/test/fuzz/spanparsing.cpp b/src/test/fuzz/spanparsing.cpp index 8e5e7dad11..e5bf5dd608 100644 --- a/src/test/fuzz/spanparsing.cpp +++ b/src/test/fuzz/spanparsing.cpp @@ -12,7 +12,7 @@ void test_one_input(const std::vector<uint8_t>& buffer) const size_t query_size = fuzzed_data_provider.ConsumeIntegral<size_t>(); const std::string query = fuzzed_data_provider.ConsumeBytesAsString(std::min<size_t>(query_size, 1024 * 1024)); const std::string span_str = fuzzed_data_provider.ConsumeRemainingBytesAsString(); - const Span<const char> const_span = MakeSpan(span_str); + const Span<const char> const_span{span_str}; Span<const char> mut_span = const_span; (void)spanparsing::Const(query, mut_span); diff --git a/src/test/fuzz/util.h b/src/test/fuzz/util.h index f26878a704..9f9552edb9 100644 --- a/src/test/fuzz/util.h +++ b/src/test/fuzz/util.h @@ -8,8 +8,11 @@ #include <amount.h> #include <arith_uint256.h> #include <attributes.h> +#include <chainparamsbase.h> #include <coins.h> #include <consensus/consensus.h> +#include <netaddress.h> +#include <netbase.h> #include <primitives/transaction.h> #include <script/script.h> #include <script/standard.h> @@ -17,12 +20,14 @@ #include <streams.h> #include <test/fuzz/FuzzedDataProvider.h> #include <test/fuzz/fuzz.h> +#include <test/util/setup_common.h> #include <txmempool.h> #include <uint256.h> #include <version.h> #include <algorithm> #include <cstdint> +#include <cstdio> #include <optional> #include <string> #include <vector> @@ -214,4 +219,255 @@ NODISCARD inline bool ContainsSpentInput(const CTransaction& tx, const CCoinsVie return false; } +/** + * Returns a byte vector of specified size regardless of the number of remaining bytes available + * from the fuzzer. Pads with zero value bytes if needed to achieve the specified size. + */ +NODISCARD inline std::vector<uint8_t> ConsumeFixedLengthByteVector(FuzzedDataProvider& fuzzed_data_provider, const size_t length) noexcept +{ + std::vector<uint8_t> result(length); + const std::vector<uint8_t> random_bytes = fuzzed_data_provider.ConsumeBytes<uint8_t>(length); + if (!random_bytes.empty()) { + std::memcpy(result.data(), random_bytes.data(), random_bytes.size()); + } + return result; +} + +CNetAddr ConsumeNetAddr(FuzzedDataProvider& fuzzed_data_provider) noexcept +{ + const Network network = fuzzed_data_provider.PickValueInArray({Network::NET_IPV4, Network::NET_IPV6, Network::NET_INTERNAL, Network::NET_ONION}); + CNetAddr net_addr; + if (network == Network::NET_IPV4) { + const in_addr v4_addr = { + .s_addr = fuzzed_data_provider.ConsumeIntegral<uint32_t>()}; + net_addr = CNetAddr{v4_addr}; + } else if (network == Network::NET_IPV6) { + if (fuzzed_data_provider.remaining_bytes() >= 16) { + in6_addr v6_addr = {}; + memcpy(v6_addr.s6_addr, fuzzed_data_provider.ConsumeBytes<uint8_t>(16).data(), 16); + net_addr = CNetAddr{v6_addr, fuzzed_data_provider.ConsumeIntegral<uint32_t>()}; + } + } else if (network == Network::NET_INTERNAL) { + net_addr.SetInternal(fuzzed_data_provider.ConsumeBytesAsString(32)); + } else if (network == Network::NET_ONION) { + net_addr.SetSpecial(fuzzed_data_provider.ConsumeBytesAsString(32)); + } + return net_addr; +} + +CSubNet ConsumeSubNet(FuzzedDataProvider& fuzzed_data_provider) noexcept +{ + return {ConsumeNetAddr(fuzzed_data_provider), fuzzed_data_provider.ConsumeIntegral<int32_t>()}; +} + +void InitializeFuzzingContext(const std::string& chain_name = CBaseChainParams::REGTEST) +{ + static const BasicTestingSetup basic_testing_setup{chain_name, {"-nodebuglogfile"}}; +} + +class FuzzedFileProvider +{ + FuzzedDataProvider& m_fuzzed_data_provider; + int64_t m_offset = 0; + +public: + FuzzedFileProvider(FuzzedDataProvider& fuzzed_data_provider) : m_fuzzed_data_provider{fuzzed_data_provider} + { + } + + FILE* open() + { + if (m_fuzzed_data_provider.ConsumeBool()) { + return nullptr; + } + std::string mode; + switch (m_fuzzed_data_provider.ConsumeIntegralInRange<int>(0, 5)) { + case 0: { + mode = "r"; + break; + } + case 1: { + mode = "r+"; + break; + } + case 2: { + mode = "w"; + break; + } + case 3: { + mode = "w+"; + break; + } + case 4: { + mode = "a"; + break; + } + case 5: { + mode = "a+"; + break; + } + } +#ifdef _GNU_SOURCE + const cookie_io_functions_t io_hooks = { + FuzzedFileProvider::read, + FuzzedFileProvider::write, + FuzzedFileProvider::seek, + FuzzedFileProvider::close, + }; + return fopencookie(this, mode.c_str(), io_hooks); +#else + (void)mode; + return nullptr; +#endif + } + + static ssize_t read(void* cookie, char* buf, size_t size) + { + FuzzedFileProvider* fuzzed_file = (FuzzedFileProvider*)cookie; + if (buf == nullptr || size == 0 || fuzzed_file->m_fuzzed_data_provider.ConsumeBool()) { + return fuzzed_file->m_fuzzed_data_provider.ConsumeBool() ? 0 : -1; + } + const std::vector<uint8_t> random_bytes = fuzzed_file->m_fuzzed_data_provider.ConsumeBytes<uint8_t>(size); + if (random_bytes.empty()) { + return 0; + } + std::memcpy(buf, random_bytes.data(), random_bytes.size()); + if (AdditionOverflow(fuzzed_file->m_offset, (int64_t)random_bytes.size())) { + return fuzzed_file->m_fuzzed_data_provider.ConsumeBool() ? 0 : -1; + } + fuzzed_file->m_offset += random_bytes.size(); + return random_bytes.size(); + } + + static ssize_t write(void* cookie, const char* buf, size_t size) + { + FuzzedFileProvider* fuzzed_file = (FuzzedFileProvider*)cookie; + const ssize_t n = fuzzed_file->m_fuzzed_data_provider.ConsumeIntegralInRange<ssize_t>(0, size); + if (AdditionOverflow(fuzzed_file->m_offset, (int64_t)n)) { + return fuzzed_file->m_fuzzed_data_provider.ConsumeBool() ? 0 : -1; + } + fuzzed_file->m_offset += n; + return n; + } + + static int seek(void* cookie, int64_t* offset, int whence) + { + assert(whence == SEEK_SET || whence == SEEK_CUR); // SEEK_END not implemented yet. + FuzzedFileProvider* fuzzed_file = (FuzzedFileProvider*)cookie; + int64_t new_offset = 0; + if (whence == SEEK_SET) { + new_offset = *offset; + } else if (whence == SEEK_CUR) { + if (AdditionOverflow(fuzzed_file->m_offset, *offset)) { + return -1; + } + new_offset = fuzzed_file->m_offset + *offset; + } + if (new_offset < 0) { + return -1; + } + fuzzed_file->m_offset = new_offset; + *offset = new_offset; + return fuzzed_file->m_fuzzed_data_provider.ConsumeIntegralInRange<int>(-1, 0); + } + + static int close(void* cookie) + { + FuzzedFileProvider* fuzzed_file = (FuzzedFileProvider*)cookie; + return fuzzed_file->m_fuzzed_data_provider.ConsumeIntegralInRange<int>(-1, 0); + } +}; + +NODISCARD inline FuzzedFileProvider ConsumeFile(FuzzedDataProvider& fuzzed_data_provider) noexcept +{ + return {fuzzed_data_provider}; +} + +class FuzzedAutoFileProvider +{ + FuzzedDataProvider& m_fuzzed_data_provider; + FuzzedFileProvider m_fuzzed_file_provider; + +public: + FuzzedAutoFileProvider(FuzzedDataProvider& fuzzed_data_provider) : m_fuzzed_data_provider{fuzzed_data_provider}, m_fuzzed_file_provider{fuzzed_data_provider} + { + } + + CAutoFile open() + { + return {m_fuzzed_file_provider.open(), m_fuzzed_data_provider.ConsumeIntegral<int>(), m_fuzzed_data_provider.ConsumeIntegral<int>()}; + } +}; + +NODISCARD inline FuzzedAutoFileProvider ConsumeAutoFile(FuzzedDataProvider& fuzzed_data_provider) noexcept +{ + return {fuzzed_data_provider}; +} + +#define WRITE_TO_STREAM_CASE(id, type, consume) \ + case id: { \ + type o = consume; \ + stream << o; \ + break; \ + } +template <typename Stream> +void WriteToStream(FuzzedDataProvider& fuzzed_data_provider, Stream& stream) noexcept +{ + while (fuzzed_data_provider.ConsumeBool()) { + try { + switch (fuzzed_data_provider.ConsumeIntegralInRange<int>(0, 13)) { + WRITE_TO_STREAM_CASE(0, bool, fuzzed_data_provider.ConsumeBool()) + WRITE_TO_STREAM_CASE(1, char, fuzzed_data_provider.ConsumeIntegral<char>()) + WRITE_TO_STREAM_CASE(2, int8_t, fuzzed_data_provider.ConsumeIntegral<int8_t>()) + WRITE_TO_STREAM_CASE(3, uint8_t, fuzzed_data_provider.ConsumeIntegral<uint8_t>()) + WRITE_TO_STREAM_CASE(4, int16_t, fuzzed_data_provider.ConsumeIntegral<int16_t>()) + WRITE_TO_STREAM_CASE(5, uint16_t, fuzzed_data_provider.ConsumeIntegral<uint16_t>()) + WRITE_TO_STREAM_CASE(6, int32_t, fuzzed_data_provider.ConsumeIntegral<int32_t>()) + WRITE_TO_STREAM_CASE(7, uint32_t, fuzzed_data_provider.ConsumeIntegral<uint32_t>()) + WRITE_TO_STREAM_CASE(8, int64_t, fuzzed_data_provider.ConsumeIntegral<int64_t>()) + WRITE_TO_STREAM_CASE(9, uint64_t, fuzzed_data_provider.ConsumeIntegral<uint64_t>()) + WRITE_TO_STREAM_CASE(10, float, fuzzed_data_provider.ConsumeFloatingPoint<float>()) + WRITE_TO_STREAM_CASE(11, double, fuzzed_data_provider.ConsumeFloatingPoint<double>()) + WRITE_TO_STREAM_CASE(12, std::string, fuzzed_data_provider.ConsumeRandomLengthString(32)) + WRITE_TO_STREAM_CASE(13, std::vector<char>, ConsumeRandomLengthIntegralVector<char>(fuzzed_data_provider)) + } + } catch (const std::ios_base::failure&) { + break; + } + } +} + +#define READ_FROM_STREAM_CASE(id, type) \ + case id: { \ + type o; \ + stream >> o; \ + break; \ + } +template <typename Stream> +void ReadFromStream(FuzzedDataProvider& fuzzed_data_provider, Stream& stream) noexcept +{ + while (fuzzed_data_provider.ConsumeBool()) { + try { + switch (fuzzed_data_provider.ConsumeIntegralInRange<int>(0, 13)) { + READ_FROM_STREAM_CASE(0, bool) + READ_FROM_STREAM_CASE(1, char) + READ_FROM_STREAM_CASE(2, int8_t) + READ_FROM_STREAM_CASE(3, uint8_t) + READ_FROM_STREAM_CASE(4, int16_t) + READ_FROM_STREAM_CASE(5, uint16_t) + READ_FROM_STREAM_CASE(6, int32_t) + READ_FROM_STREAM_CASE(7, uint32_t) + READ_FROM_STREAM_CASE(8, int64_t) + READ_FROM_STREAM_CASE(9, uint64_t) + READ_FROM_STREAM_CASE(10, float) + READ_FROM_STREAM_CASE(11, double) + READ_FROM_STREAM_CASE(12, std::string) + READ_FROM_STREAM_CASE(13, std::vector<char>) + } + } catch (const std::ios_base::failure&) { + break; + } + } +} + #endif // BITCOIN_TEST_FUZZ_UTIL_H diff --git a/src/test/key_tests.cpp b/src/test/key_tests.cpp index cf2bd03698..4e4c44266a 100644 --- a/src/test/key_tests.cpp +++ b/src/test/key_tests.cpp @@ -5,6 +5,7 @@ #include <key.h> #include <key_io.h> +#include <streams.h> #include <test/util/setup_common.h> #include <uint256.h> #include <util/strencodings.h> @@ -76,7 +77,7 @@ BOOST_AUTO_TEST_CASE(key_test1) for (int n=0; n<16; n++) { std::string strMsg = strprintf("Very secret message %i: 11", n); - uint256 hashMsg = Hash(strMsg.begin(), strMsg.end()); + uint256 hashMsg = Hash(strMsg); // normal signatures @@ -133,7 +134,7 @@ BOOST_AUTO_TEST_CASE(key_test1) std::vector<unsigned char> detsig, detsigc; std::string strMsg = "Very deterministic message"; - uint256 hashMsg = Hash(strMsg.begin(), strMsg.end()); + uint256 hashMsg = Hash(strMsg); BOOST_CHECK(key1.Sign(hashMsg, detsig)); BOOST_CHECK(key1C.Sign(hashMsg, detsigc)); BOOST_CHECK(detsig == detsigc); @@ -157,7 +158,7 @@ BOOST_AUTO_TEST_CASE(key_signature_tests) // When entropy is specified, we should see at least one high R signature within 20 signatures CKey key = DecodeSecret(strSecret1); std::string msg = "A message to be signed"; - uint256 msg_hash = Hash(msg.begin(), msg.end()); + uint256 msg_hash = Hash(msg); std::vector<unsigned char> sig; bool found = false; @@ -178,7 +179,7 @@ BOOST_AUTO_TEST_CASE(key_signature_tests) for (int i = 0; i < 256; ++i) { sig.clear(); std::string msg = "A message to be signed" + ToString(i); - msg_hash = Hash(msg.begin(), msg.end()); + msg_hash = Hash(msg); BOOST_CHECK(key.Sign(msg_hash, sig)); found = sig[3] == 0x20; BOOST_CHECK(sig.size() <= 70); @@ -195,7 +196,7 @@ BOOST_AUTO_TEST_CASE(key_key_negation) std::string str = "Bitcoin key verification\n"; GetRandBytes(rnd, sizeof(rnd)); uint256 hash; - CHash256().Write((unsigned char*)str.data(), str.size()).Write(rnd, sizeof(rnd)).Finalize(hash.begin()); + CHash256().Write(MakeUCharSpan(str)).Write(rnd).Finalize(hash); // import the static test key CKey key = DecodeSecret(strSecret1C); @@ -220,4 +221,47 @@ BOOST_AUTO_TEST_CASE(key_key_negation) BOOST_CHECK(key.GetPubKey().data()[0] == 0x03); } +static CPubKey UnserializePubkey(const std::vector<uint8_t>& data) +{ + CDataStream stream{SER_NETWORK, INIT_PROTO_VERSION}; + stream << data; + CPubKey pubkey; + stream >> pubkey; + return pubkey; +} + +static unsigned int GetLen(unsigned char chHeader) +{ + if (chHeader == 2 || chHeader == 3) + return CPubKey::COMPRESSED_SIZE; + if (chHeader == 4 || chHeader == 6 || chHeader == 7) + return CPubKey::SIZE; + return 0; +} + +static void CmpSerializationPubkey(const CPubKey& pubkey) +{ + CDataStream stream{SER_NETWORK, INIT_PROTO_VERSION}; + stream << pubkey; + CPubKey pubkey2; + stream >> pubkey2; + BOOST_CHECK(pubkey == pubkey2); +} + +BOOST_AUTO_TEST_CASE(pubkey_unserialize) +{ + for (uint8_t i = 2; i <= 7; ++i) { + CPubKey key = UnserializePubkey({0x02}); + BOOST_CHECK(!key.IsValid()); + CmpSerializationPubkey(key); + key = UnserializePubkey(std::vector<uint8_t>(GetLen(i), i)); + CmpSerializationPubkey(key); + if (i == 5) { + BOOST_CHECK(!key.IsValid()); + } else { + BOOST_CHECK(key.IsValid()); + } + } +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/src/test/merkle_tests.cpp b/src/test/merkle_tests.cpp index 03dce552fc..9bc7cc5dab 100644 --- a/src/test/merkle_tests.cpp +++ b/src/test/merkle_tests.cpp @@ -13,9 +13,9 @@ static uint256 ComputeMerkleRootFromBranch(const uint256& leaf, const std::vecto uint256 hash = leaf; for (std::vector<uint256>::const_iterator it = vMerkleBranch.begin(); it != vMerkleBranch.end(); ++it) { if (nIndex & 1) { - hash = Hash(it->begin(), it->end(), hash.begin(), hash.end()); + hash = Hash(*it, hash); } else { - hash = Hash(hash.begin(), hash.end(), it->begin(), it->end()); + hash = Hash(hash, *it); } nIndex >>= 1; } @@ -60,7 +60,7 @@ static void MerkleComputation(const std::vector<uint256>& leaves, uint256* proot } } mutated |= (inner[level] == h); - CHash256().Write(inner[level].begin(), 32).Write(h.begin(), 32).Finalize(h.begin()); + CHash256().Write(inner[level]).Write(h).Finalize(h); } // Store the resulting hash at inner position level. inner[level] = h; @@ -86,7 +86,7 @@ static void MerkleComputation(const std::vector<uint256>& leaves, uint256* proot if (pbranch && matchh) { pbranch->push_back(h); } - CHash256().Write(h.begin(), 32).Write(h.begin(), 32).Finalize(h.begin()); + CHash256().Write(h).Write(h).Finalize(h); // Increment count to the value it would have if two entries at this // level had existed. count += (((uint32_t)1) << level); @@ -101,7 +101,7 @@ static void MerkleComputation(const std::vector<uint256>& leaves, uint256* proot matchh = true; } } - CHash256().Write(inner[level].begin(), 32).Write(h.begin(), 32).Finalize(h.begin()); + CHash256().Write(inner[level]).Write(h).Finalize(h); level++; } } @@ -144,8 +144,7 @@ static uint256 BlockBuildMerkleTree(const CBlock& block, bool* fMutated, std::ve // Two identical hashes at the end of the list at a particular level. mutated = true; } - vMerkleTree.push_back(Hash(vMerkleTree[j+i].begin(), vMerkleTree[j+i].end(), - vMerkleTree[j+i2].begin(), vMerkleTree[j+i2].end())); + vMerkleTree.push_back(Hash(vMerkleTree[j+i], vMerkleTree[j+i2])); } j += nSize; } diff --git a/src/test/miner_tests.cpp b/src/test/miner_tests.cpp index 57eee94330..62a0dc4241 100644 --- a/src/test/miner_tests.cpp +++ b/src/test/miner_tests.cpp @@ -253,7 +253,7 @@ BOOST_AUTO_TEST_CASE(CreateNewBlock_validity) pblock->nNonce = blockinfo[i].nonce; } std::shared_ptr<const CBlock> shared_pblock = std::make_shared<const CBlock>(*pblock); - BOOST_CHECK(EnsureChainman(m_node).ProcessNewBlock(chainparams, shared_pblock, true, nullptr)); + BOOST_CHECK(Assert(m_node.chainman)->ProcessNewBlock(chainparams, shared_pblock, true, nullptr)); pblock->hashPrevBlock = pblock->GetHash(); } @@ -448,7 +448,7 @@ BOOST_AUTO_TEST_CASE(CreateNewBlock_validity) m_node.mempool->addUnchecked(entry.Fee(HIGHFEE).Time(GetTime()).SpendsCoinbase(true).FromTx(tx)); BOOST_CHECK(CheckFinalTx(CTransaction(tx), flags)); // Locktime passes BOOST_CHECK(!TestSequenceLocks(CTransaction(tx), flags)); // Sequence locks fail - BOOST_CHECK(SequenceLocks(CTransaction(tx), flags, &prevheights, CreateBlockIndex(::ChainActive().Tip()->nHeight + 2))); // Sequence locks pass on 2nd block + BOOST_CHECK(SequenceLocks(CTransaction(tx), flags, prevheights, CreateBlockIndex(::ChainActive().Tip()->nHeight + 2))); // Sequence locks pass on 2nd block // relative time locked tx.vin[0].prevout.hash = txFirst[1]->GetHash(); @@ -461,7 +461,7 @@ BOOST_AUTO_TEST_CASE(CreateNewBlock_validity) for (int i = 0; i < CBlockIndex::nMedianTimeSpan; i++) ::ChainActive().Tip()->GetAncestor(::ChainActive().Tip()->nHeight - i)->nTime += 512; //Trick the MedianTimePast - BOOST_CHECK(SequenceLocks(CTransaction(tx), flags, &prevheights, CreateBlockIndex(::ChainActive().Tip()->nHeight + 1))); // Sequence locks pass 512 seconds later + BOOST_CHECK(SequenceLocks(CTransaction(tx), flags, prevheights, CreateBlockIndex(::ChainActive().Tip()->nHeight + 1))); // Sequence locks pass 512 seconds later for (int i = 0; i < CBlockIndex::nMedianTimeSpan; i++) ::ChainActive().Tip()->GetAncestor(::ChainActive().Tip()->nHeight - i)->nTime -= 512; //undo tricked MTP diff --git a/src/test/multisig_tests.cpp b/src/test/multisig_tests.cpp index dd2890c134..e14d2dd72d 100644 --- a/src/test/multisig_tests.cpp +++ b/src/test/multisig_tests.cpp @@ -141,7 +141,7 @@ BOOST_AUTO_TEST_CASE(multisig_IsStandard) for (int i = 0; i < 4; i++) key[i].MakeNewKey(true); - txnouttype whichType; + TxoutType whichType; CScript a_and_b; a_and_b << OP_2 << ToByteVector(key[0].GetPubKey()) << ToByteVector(key[1].GetPubKey()) << OP_2 << OP_CHECKMULTISIG; diff --git a/src/test/net_tests.cpp b/src/test/net_tests.cpp index 84bf593497..317000c771 100644 --- a/src/test/net_tests.cpp +++ b/src/test/net_tests.cpp @@ -6,6 +6,7 @@ #include <addrman.h> #include <chainparams.h> #include <clientversion.h> +#include <cstdint> #include <net.h> #include <netbase.h> #include <serialize.h> @@ -83,10 +84,10 @@ BOOST_FIXTURE_TEST_SUITE(net_tests, BasicTestingSetup) BOOST_AUTO_TEST_CASE(cnode_listen_port) { // test default - unsigned short port = GetListenPort(); + uint16_t port = GetListenPort(); BOOST_CHECK(port == Params().GetDefaultPort()); // test set port - unsigned short altPort = 12345; + uint16_t altPort = 12345; BOOST_CHECK(gArgs.SoftSetArg("-port", ToString(altPort))); port = GetListenPort(); BOOST_CHECK(port == altPort); @@ -179,17 +180,12 @@ BOOST_AUTO_TEST_CASE(cnode_simple_test) CAddress addr = CAddress(CService(ipv4Addr, 7777), NODE_NETWORK); std::string pszDest; - bool fInboundIn = false; - // Test that fFeeler is false by default. - std::unique_ptr<CNode> pnode1 = MakeUnique<CNode>(id++, NODE_NETWORK, height, hSocket, addr, 0, 0, CAddress(), pszDest, fInboundIn); - BOOST_CHECK(pnode1->fInbound == false); - BOOST_CHECK(pnode1->fFeeler == false); + std::unique_ptr<CNode> pnode1 = MakeUnique<CNode>(id++, NODE_NETWORK, height, hSocket, addr, 0, 0, CAddress(), pszDest, ConnectionType::OUTBOUND); + BOOST_CHECK(pnode1->IsInboundConn() == false); - fInboundIn = true; - std::unique_ptr<CNode> pnode2 = MakeUnique<CNode>(id++, NODE_NETWORK, height, hSocket, addr, 1, 1, CAddress(), pszDest, fInboundIn); - BOOST_CHECK(pnode2->fInbound == true); - BOOST_CHECK(pnode2->fFeeler == false); + std::unique_ptr<CNode> pnode2 = MakeUnique<CNode>(id++, NODE_NETWORK, height, hSocket, addr, 1, 1, CAddress(), pszDest, ConnectionType::INBOUND); + BOOST_CHECK(pnode2->IsInboundConn() == true); } // prior to PR #14728, this test triggers an undefined behavior @@ -213,7 +209,7 @@ BOOST_AUTO_TEST_CASE(ipv4_peer_with_ipv6_addrMe_test) in_addr ipv4AddrPeer; ipv4AddrPeer.s_addr = 0xa0b0c001; CAddress addr = CAddress(CService(ipv4AddrPeer, 7777), NODE_NETWORK); - std::unique_ptr<CNode> pnode = MakeUnique<CNode>(0, NODE_NETWORK, 0, INVALID_SOCKET, addr, 0, 0, CAddress{}, std::string{}, false); + std::unique_ptr<CNode> pnode = MakeUnique<CNode>(0, NODE_NETWORK, 0, INVALID_SOCKET, addr, 0, 0, CAddress{}, std::string{}, ConnectionType::OUTBOUND); pnode->fSuccessfullyConnected.store(true); // the peer claims to be reaching us via IPv6 diff --git a/src/test/netbase_tests.cpp b/src/test/netbase_tests.cpp index 2e1972cc3f..49073ea657 100644 --- a/src/test/netbase_tests.cpp +++ b/src/test/netbase_tests.cpp @@ -6,6 +6,7 @@ #include <netbase.h> #include <test/util/setup_common.h> #include <util/strencodings.h> +#include <util/translation.h> #include <string> @@ -137,6 +138,14 @@ BOOST_AUTO_TEST_CASE(onioncat_test) } +BOOST_AUTO_TEST_CASE(embedded_test) +{ + CNetAddr addr1(ResolveIP("1.2.3.4")); + CNetAddr addr2(ResolveIP("::FFFF:0102:0304")); + BOOST_CHECK(addr2.IsIPv4()); + BOOST_CHECK_EQUAL(addr1.ToString(), addr2.ToString()); +} + BOOST_AUTO_TEST_CASE(subnet_test) { @@ -157,9 +166,13 @@ BOOST_AUTO_TEST_CASE(subnet_test) BOOST_CHECK(ResolveSubNet("1.2.2.1/24").Match(ResolveIP("1.2.2.4"))); BOOST_CHECK(ResolveSubNet("1.2.2.110/31").Match(ResolveIP("1.2.2.111"))); BOOST_CHECK(ResolveSubNet("1.2.2.20/26").Match(ResolveIP("1.2.2.63"))); - // All-Matching IPv6 Matches arbitrary IPv4 and IPv6 + // All-Matching IPv6 Matches arbitrary IPv6 BOOST_CHECK(ResolveSubNet("::/0").Match(ResolveIP("1:2:3:4:5:6:7:1234"))); - BOOST_CHECK(ResolveSubNet("::/0").Match(ResolveIP("1.2.3.4"))); + // But not `::` or `0.0.0.0` because they are considered invalid addresses + BOOST_CHECK(!ResolveSubNet("::/0").Match(ResolveIP("::"))); + BOOST_CHECK(!ResolveSubNet("::/0").Match(ResolveIP("0.0.0.0"))); + // Addresses from one network (IPv4) don't belong to subnets of another network (IPv6) + BOOST_CHECK(!ResolveSubNet("::/0").Match(ResolveIP("1.2.3.4"))); // All-Matching IPv4 does not Match IPv6 BOOST_CHECK(!ResolveSubNet("0.0.0.0/0").Match(ResolveIP("1:2:3:4:5:6:7:1234"))); // Invalid subnets Match nothing (not even invalid addresses) @@ -325,15 +338,15 @@ BOOST_AUTO_TEST_CASE(netbase_parsenetwork) BOOST_AUTO_TEST_CASE(netpermissions_test) { - std::string error; + bilingual_str error; NetWhitebindPermissions whitebindPermissions; NetWhitelistPermissions whitelistPermissions; // Detect invalid white bind BOOST_CHECK(!NetWhitebindPermissions::TryParse("", whitebindPermissions, error)); - BOOST_CHECK(error.find("Cannot resolve -whitebind address") != std::string::npos); + BOOST_CHECK(error.original.find("Cannot resolve -whitebind address") != std::string::npos); BOOST_CHECK(!NetWhitebindPermissions::TryParse("127.0.0.1", whitebindPermissions, error)); - BOOST_CHECK(error.find("Need to specify a port with -whitebind") != std::string::npos); + BOOST_CHECK(error.original.find("Need to specify a port with -whitebind") != std::string::npos); BOOST_CHECK(!NetWhitebindPermissions::TryParse("", whitebindPermissions, error)); // If no permission flags, assume backward compatibility @@ -377,11 +390,11 @@ BOOST_AUTO_TEST_CASE(netpermissions_test) // Detect invalid flag BOOST_CHECK(!NetWhitebindPermissions::TryParse("bloom,forcerelay,oopsie@1.2.3.4:32", whitebindPermissions, error)); - BOOST_CHECK(error.find("Invalid P2P permission") != std::string::npos); + BOOST_CHECK(error.original.find("Invalid P2P permission") != std::string::npos); - // Check whitelist error + // Check netmask error BOOST_CHECK(!NetWhitelistPermissions::TryParse("bloom,forcerelay,noban@1.2.3.4:32", whitelistPermissions, error)); - BOOST_CHECK(error.find("Invalid netmask specified in -whitelist") != std::string::npos); + BOOST_CHECK(error.original.find("Invalid netmask specified in -whitelist") != std::string::npos); // Happy path for whitelist parsing BOOST_CHECK(NetWhitelistPermissions::TryParse("noban@1.2.3.4", whitelistPermissions, error)); @@ -393,12 +406,14 @@ BOOST_AUTO_TEST_CASE(netpermissions_test) BOOST_CHECK(NetWhitelistPermissions::TryParse("bloom,forcerelay,noban,relay,mempool@1.2.3.4/32", whitelistPermissions, error)); const auto strings = NetPermissions::ToStrings(PF_ALL); - BOOST_CHECK_EQUAL(strings.size(), 5U); + BOOST_CHECK_EQUAL(strings.size(), 7U); BOOST_CHECK(std::find(strings.begin(), strings.end(), "bloomfilter") != strings.end()); BOOST_CHECK(std::find(strings.begin(), strings.end(), "forcerelay") != strings.end()); BOOST_CHECK(std::find(strings.begin(), strings.end(), "relay") != strings.end()); BOOST_CHECK(std::find(strings.begin(), strings.end(), "noban") != strings.end()); BOOST_CHECK(std::find(strings.begin(), strings.end(), "mempool") != strings.end()); + BOOST_CHECK(std::find(strings.begin(), strings.end(), "download") != strings.end()); + BOOST_CHECK(std::find(strings.begin(), strings.end(), "addr") != strings.end()); } BOOST_AUTO_TEST_CASE(netbase_dont_resolve_strings_with_embedded_nul_characters) diff --git a/src/test/policy_fee_tests.cpp b/src/test/policy_fee_tests.cpp new file mode 100644 index 0000000000..6d8872b11e --- /dev/null +++ b/src/test/policy_fee_tests.cpp @@ -0,0 +1,34 @@ +// Copyright (c) 2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include <amount.h> +#include <policy/fees.h> + +#include <test/util/setup_common.h> + +#include <boost/test/unit_test.hpp> + +BOOST_FIXTURE_TEST_SUITE(policy_fee_tests, BasicTestingSetup) + +BOOST_AUTO_TEST_CASE(FeeRounder) +{ + FeeFilterRounder fee_rounder{CFeeRate{1000}}; + + // check that 1000 rounds to 974 or 1071 + std::set<CAmount> results; + while (results.size() < 2) { + results.emplace(fee_rounder.round(1000)); + } + BOOST_CHECK_EQUAL(*results.begin(), 974); + BOOST_CHECK_EQUAL(*++results.begin(), 1071); + + // check that negative amounts rounds to 0 + BOOST_CHECK_EQUAL(fee_rounder.round(-0), 0); + BOOST_CHECK_EQUAL(fee_rounder.round(-1), 0); + + // check that MAX_MONEY rounds to 9170997 + BOOST_CHECK_EQUAL(fee_rounder.round(MAX_MONEY), 9170997); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/test/scheduler_tests.cpp b/src/test/scheduler_tests.cpp index 1395a7f38c..2e5a7549b7 100644 --- a/src/test/scheduler_tests.cpp +++ b/src/test/scheduler_tests.cpp @@ -7,7 +7,7 @@ #include <util/time.h> #include <boost/test/unit_test.hpp> -#include <boost/thread.hpp> +#include <boost/thread/thread.hpp> #include <mutex> @@ -89,7 +89,7 @@ BOOST_AUTO_TEST_CASE(manythreads) } // Drain the task queue then exit threads - microTasks.stop(true); + microTasks.StopWhenDrained(); microThreads.join_all(); // ... wait until all the threads are done int counterSum = 0; @@ -155,7 +155,7 @@ BOOST_AUTO_TEST_CASE(singlethreadedscheduler_ordered) } // finish up - scheduler.stop(true); + scheduler.StopWhenDrained(); threads.join_all(); BOOST_CHECK_EQUAL(counter1, 100); @@ -186,7 +186,7 @@ BOOST_AUTO_TEST_CASE(mockforward) scheduler.MockForward(std::chrono::minutes{5}); // ensure scheduler has chance to process all tasks queued for before 1 ms from now. - scheduler.scheduleFromNow([&scheduler] { scheduler.stop(false); }, std::chrono::milliseconds{1}); + scheduler.scheduleFromNow([&scheduler] { scheduler.stop(); }, std::chrono::milliseconds{1}); scheduler_thread.join(); // check that the queue only has one job remaining diff --git a/src/test/script_standard_tests.cpp b/src/test/script_standard_tests.cpp index b185d3b4ac..87678af4d1 100644 --- a/src/test/script_standard_tests.cpp +++ b/src/test/script_standard_tests.cpp @@ -31,35 +31,35 @@ BOOST_AUTO_TEST_CASE(script_standard_Solver_success) CScript s; std::vector<std::vector<unsigned char> > solutions; - // TX_PUBKEY + // TxoutType::PUBKEY s.clear(); s << ToByteVector(pubkeys[0]) << OP_CHECKSIG; - BOOST_CHECK_EQUAL(Solver(s, solutions), TX_PUBKEY); + BOOST_CHECK_EQUAL(Solver(s, solutions), TxoutType::PUBKEY); BOOST_CHECK_EQUAL(solutions.size(), 1U); BOOST_CHECK(solutions[0] == ToByteVector(pubkeys[0])); - // TX_PUBKEYHASH + // TxoutType::PUBKEYHASH s.clear(); s << OP_DUP << OP_HASH160 << ToByteVector(pubkeys[0].GetID()) << OP_EQUALVERIFY << OP_CHECKSIG; - BOOST_CHECK_EQUAL(Solver(s, solutions), TX_PUBKEYHASH); + BOOST_CHECK_EQUAL(Solver(s, solutions), TxoutType::PUBKEYHASH); BOOST_CHECK_EQUAL(solutions.size(), 1U); BOOST_CHECK(solutions[0] == ToByteVector(pubkeys[0].GetID())); - // TX_SCRIPTHASH + // TxoutType::SCRIPTHASH CScript redeemScript(s); // initialize with leftover P2PKH script s.clear(); s << OP_HASH160 << ToByteVector(CScriptID(redeemScript)) << OP_EQUAL; - BOOST_CHECK_EQUAL(Solver(s, solutions), TX_SCRIPTHASH); + BOOST_CHECK_EQUAL(Solver(s, solutions), TxoutType::SCRIPTHASH); BOOST_CHECK_EQUAL(solutions.size(), 1U); BOOST_CHECK(solutions[0] == ToByteVector(CScriptID(redeemScript))); - // TX_MULTISIG + // TxoutType::MULTISIG s.clear(); s << OP_1 << ToByteVector(pubkeys[0]) << ToByteVector(pubkeys[1]) << OP_2 << OP_CHECKMULTISIG; - BOOST_CHECK_EQUAL(Solver(s, solutions), TX_MULTISIG); + BOOST_CHECK_EQUAL(Solver(s, solutions), TxoutType::MULTISIG); BOOST_CHECK_EQUAL(solutions.size(), 4U); BOOST_CHECK(solutions[0] == std::vector<unsigned char>({1})); BOOST_CHECK(solutions[1] == ToByteVector(pubkeys[0])); @@ -72,7 +72,7 @@ BOOST_AUTO_TEST_CASE(script_standard_Solver_success) ToByteVector(pubkeys[1]) << ToByteVector(pubkeys[2]) << OP_3 << OP_CHECKMULTISIG; - BOOST_CHECK_EQUAL(Solver(s, solutions), TX_MULTISIG); + BOOST_CHECK_EQUAL(Solver(s, solutions), TxoutType::MULTISIG); BOOST_CHECK_EQUAL(solutions.size(), 5U); BOOST_CHECK(solutions[0] == std::vector<unsigned char>({2})); BOOST_CHECK(solutions[1] == ToByteVector(pubkeys[0])); @@ -80,37 +80,37 @@ BOOST_AUTO_TEST_CASE(script_standard_Solver_success) BOOST_CHECK(solutions[3] == ToByteVector(pubkeys[2])); BOOST_CHECK(solutions[4] == std::vector<unsigned char>({3})); - // TX_NULL_DATA + // TxoutType::NULL_DATA s.clear(); s << OP_RETURN << std::vector<unsigned char>({0}) << std::vector<unsigned char>({75}) << std::vector<unsigned char>({255}); - BOOST_CHECK_EQUAL(Solver(s, solutions), TX_NULL_DATA); + BOOST_CHECK_EQUAL(Solver(s, solutions), TxoutType::NULL_DATA); BOOST_CHECK_EQUAL(solutions.size(), 0U); - // TX_WITNESS_V0_KEYHASH + // TxoutType::WITNESS_V0_KEYHASH s.clear(); s << OP_0 << ToByteVector(pubkeys[0].GetID()); - BOOST_CHECK_EQUAL(Solver(s, solutions), TX_WITNESS_V0_KEYHASH); + BOOST_CHECK_EQUAL(Solver(s, solutions), TxoutType::WITNESS_V0_KEYHASH); BOOST_CHECK_EQUAL(solutions.size(), 1U); BOOST_CHECK(solutions[0] == ToByteVector(pubkeys[0].GetID())); - // TX_WITNESS_V0_SCRIPTHASH + // TxoutType::WITNESS_V0_SCRIPTHASH uint256 scriptHash; CSHA256().Write(&redeemScript[0], redeemScript.size()) .Finalize(scriptHash.begin()); s.clear(); s << OP_0 << ToByteVector(scriptHash); - BOOST_CHECK_EQUAL(Solver(s, solutions), TX_WITNESS_V0_SCRIPTHASH); + BOOST_CHECK_EQUAL(Solver(s, solutions), TxoutType::WITNESS_V0_SCRIPTHASH); BOOST_CHECK_EQUAL(solutions.size(), 1U); BOOST_CHECK(solutions[0] == ToByteVector(scriptHash)); - // TX_NONSTANDARD + // TxoutType::NONSTANDARD s.clear(); s << OP_9 << OP_ADD << OP_11 << OP_EQUAL; - BOOST_CHECK_EQUAL(Solver(s, solutions), TX_NONSTANDARD); + BOOST_CHECK_EQUAL(Solver(s, solutions), TxoutType::NONSTANDARD); } BOOST_AUTO_TEST_CASE(script_standard_Solver_failure) @@ -123,50 +123,50 @@ BOOST_AUTO_TEST_CASE(script_standard_Solver_failure) CScript s; std::vector<std::vector<unsigned char> > solutions; - // TX_PUBKEY with incorrectly sized pubkey + // TxoutType::PUBKEY with incorrectly sized pubkey s.clear(); s << std::vector<unsigned char>(30, 0x01) << OP_CHECKSIG; - BOOST_CHECK_EQUAL(Solver(s, solutions), TX_NONSTANDARD); + BOOST_CHECK_EQUAL(Solver(s, solutions), TxoutType::NONSTANDARD); - // TX_PUBKEYHASH with incorrectly sized key hash + // TxoutType::PUBKEYHASH with incorrectly sized key hash s.clear(); s << OP_DUP << OP_HASH160 << ToByteVector(pubkey) << OP_EQUALVERIFY << OP_CHECKSIG; - BOOST_CHECK_EQUAL(Solver(s, solutions), TX_NONSTANDARD); + BOOST_CHECK_EQUAL(Solver(s, solutions), TxoutType::NONSTANDARD); - // TX_SCRIPTHASH with incorrectly sized script hash + // TxoutType::SCRIPTHASH with incorrectly sized script hash s.clear(); s << OP_HASH160 << std::vector<unsigned char>(21, 0x01) << OP_EQUAL; - BOOST_CHECK_EQUAL(Solver(s, solutions), TX_NONSTANDARD); + BOOST_CHECK_EQUAL(Solver(s, solutions), TxoutType::NONSTANDARD); - // TX_MULTISIG 0/2 + // TxoutType::MULTISIG 0/2 s.clear(); s << OP_0 << ToByteVector(pubkey) << OP_1 << OP_CHECKMULTISIG; - BOOST_CHECK_EQUAL(Solver(s, solutions), TX_NONSTANDARD); + BOOST_CHECK_EQUAL(Solver(s, solutions), TxoutType::NONSTANDARD); - // TX_MULTISIG 2/1 + // TxoutType::MULTISIG 2/1 s.clear(); s << OP_2 << ToByteVector(pubkey) << OP_1 << OP_CHECKMULTISIG; - BOOST_CHECK_EQUAL(Solver(s, solutions), TX_NONSTANDARD); + BOOST_CHECK_EQUAL(Solver(s, solutions), TxoutType::NONSTANDARD); - // TX_MULTISIG n = 2 with 1 pubkey + // TxoutType::MULTISIG n = 2 with 1 pubkey s.clear(); s << OP_1 << ToByteVector(pubkey) << OP_2 << OP_CHECKMULTISIG; - BOOST_CHECK_EQUAL(Solver(s, solutions), TX_NONSTANDARD); + BOOST_CHECK_EQUAL(Solver(s, solutions), TxoutType::NONSTANDARD); - // TX_MULTISIG n = 1 with 0 pubkeys + // TxoutType::MULTISIG n = 1 with 0 pubkeys s.clear(); s << OP_1 << OP_1 << OP_CHECKMULTISIG; - BOOST_CHECK_EQUAL(Solver(s, solutions), TX_NONSTANDARD); + BOOST_CHECK_EQUAL(Solver(s, solutions), TxoutType::NONSTANDARD); - // TX_NULL_DATA with other opcodes + // TxoutType::NULL_DATA with other opcodes s.clear(); s << OP_RETURN << std::vector<unsigned char>({75}) << OP_ADD; - BOOST_CHECK_EQUAL(Solver(s, solutions), TX_NONSTANDARD); + BOOST_CHECK_EQUAL(Solver(s, solutions), TxoutType::NONSTANDARD); - // TX_WITNESS with incorrect program size + // TxoutType::WITNESS_UNKNOWN with incorrect program size s.clear(); s << OP_0 << std::vector<unsigned char>(19, 0x01); - BOOST_CHECK_EQUAL(Solver(s, solutions), TX_NONSTANDARD); + BOOST_CHECK_EQUAL(Solver(s, solutions), TxoutType::NONSTANDARD); } BOOST_AUTO_TEST_CASE(script_standard_ExtractDestination) @@ -179,21 +179,21 @@ BOOST_AUTO_TEST_CASE(script_standard_ExtractDestination) CScript s; CTxDestination address; - // TX_PUBKEY + // TxoutType::PUBKEY s.clear(); s << ToByteVector(pubkey) << OP_CHECKSIG; BOOST_CHECK(ExtractDestination(s, address)); BOOST_CHECK(boost::get<PKHash>(&address) && *boost::get<PKHash>(&address) == PKHash(pubkey)); - // TX_PUBKEYHASH + // TxoutType::PUBKEYHASH s.clear(); s << OP_DUP << OP_HASH160 << ToByteVector(pubkey.GetID()) << OP_EQUALVERIFY << OP_CHECKSIG; BOOST_CHECK(ExtractDestination(s, address)); BOOST_CHECK(boost::get<PKHash>(&address) && *boost::get<PKHash>(&address) == PKHash(pubkey)); - // TX_SCRIPTHASH + // TxoutType::SCRIPTHASH CScript redeemScript(s); // initialize with leftover P2PKH script s.clear(); s << OP_HASH160 << ToByteVector(CScriptID(redeemScript)) << OP_EQUAL; @@ -201,25 +201,25 @@ BOOST_AUTO_TEST_CASE(script_standard_ExtractDestination) BOOST_CHECK(boost::get<ScriptHash>(&address) && *boost::get<ScriptHash>(&address) == ScriptHash(redeemScript)); - // TX_MULTISIG + // TxoutType::MULTISIG s.clear(); s << OP_1 << ToByteVector(pubkey) << OP_1 << OP_CHECKMULTISIG; BOOST_CHECK(!ExtractDestination(s, address)); - // TX_NULL_DATA + // TxoutType::NULL_DATA s.clear(); s << OP_RETURN << std::vector<unsigned char>({75}); BOOST_CHECK(!ExtractDestination(s, address)); - // TX_WITNESS_V0_KEYHASH + // TxoutType::WITNESS_V0_KEYHASH s.clear(); s << OP_0 << ToByteVector(pubkey.GetID()); BOOST_CHECK(ExtractDestination(s, address)); WitnessV0KeyHash keyhash; - CHash160().Write(pubkey.begin(), pubkey.size()).Finalize(keyhash.begin()); + CHash160().Write(pubkey).Finalize(keyhash); BOOST_CHECK(boost::get<WitnessV0KeyHash>(&address) && *boost::get<WitnessV0KeyHash>(&address) == keyhash); - // TX_WITNESS_V0_SCRIPTHASH + // TxoutType::WITNESS_V0_SCRIPTHASH s.clear(); WitnessV0ScriptHash scripthash; CSHA256().Write(redeemScript.data(), redeemScript.size()).Finalize(scripthash.begin()); @@ -227,7 +227,7 @@ BOOST_AUTO_TEST_CASE(script_standard_ExtractDestination) BOOST_CHECK(ExtractDestination(s, address)); BOOST_CHECK(boost::get<WitnessV0ScriptHash>(&address) && *boost::get<WitnessV0ScriptHash>(&address) == scripthash); - // TX_WITNESS with unknown version + // TxoutType::WITNESS_UNKNOWN with unknown version s.clear(); s << OP_1 << ToByteVector(pubkey); BOOST_CHECK(ExtractDestination(s, address)); @@ -248,49 +248,49 @@ BOOST_AUTO_TEST_CASE(script_standard_ExtractDestinations) } CScript s; - txnouttype whichType; + TxoutType whichType; std::vector<CTxDestination> addresses; int nRequired; - // TX_PUBKEY + // TxoutType::PUBKEY s.clear(); s << ToByteVector(pubkeys[0]) << OP_CHECKSIG; BOOST_CHECK(ExtractDestinations(s, whichType, addresses, nRequired)); - BOOST_CHECK_EQUAL(whichType, TX_PUBKEY); + BOOST_CHECK_EQUAL(whichType, TxoutType::PUBKEY); BOOST_CHECK_EQUAL(addresses.size(), 1U); BOOST_CHECK_EQUAL(nRequired, 1); BOOST_CHECK(boost::get<PKHash>(&addresses[0]) && *boost::get<PKHash>(&addresses[0]) == PKHash(pubkeys[0])); - // TX_PUBKEYHASH + // TxoutType::PUBKEYHASH s.clear(); s << OP_DUP << OP_HASH160 << ToByteVector(pubkeys[0].GetID()) << OP_EQUALVERIFY << OP_CHECKSIG; BOOST_CHECK(ExtractDestinations(s, whichType, addresses, nRequired)); - BOOST_CHECK_EQUAL(whichType, TX_PUBKEYHASH); + BOOST_CHECK_EQUAL(whichType, TxoutType::PUBKEYHASH); BOOST_CHECK_EQUAL(addresses.size(), 1U); BOOST_CHECK_EQUAL(nRequired, 1); BOOST_CHECK(boost::get<PKHash>(&addresses[0]) && *boost::get<PKHash>(&addresses[0]) == PKHash(pubkeys[0])); - // TX_SCRIPTHASH + // TxoutType::SCRIPTHASH CScript redeemScript(s); // initialize with leftover P2PKH script s.clear(); s << OP_HASH160 << ToByteVector(CScriptID(redeemScript)) << OP_EQUAL; BOOST_CHECK(ExtractDestinations(s, whichType, addresses, nRequired)); - BOOST_CHECK_EQUAL(whichType, TX_SCRIPTHASH); + BOOST_CHECK_EQUAL(whichType, TxoutType::SCRIPTHASH); BOOST_CHECK_EQUAL(addresses.size(), 1U); BOOST_CHECK_EQUAL(nRequired, 1); BOOST_CHECK(boost::get<ScriptHash>(&addresses[0]) && *boost::get<ScriptHash>(&addresses[0]) == ScriptHash(redeemScript)); - // TX_MULTISIG + // TxoutType::MULTISIG s.clear(); s << OP_2 << ToByteVector(pubkeys[0]) << ToByteVector(pubkeys[1]) << OP_2 << OP_CHECKMULTISIG; BOOST_CHECK(ExtractDestinations(s, whichType, addresses, nRequired)); - BOOST_CHECK_EQUAL(whichType, TX_MULTISIG); + BOOST_CHECK_EQUAL(whichType, TxoutType::MULTISIG); BOOST_CHECK_EQUAL(addresses.size(), 2U); BOOST_CHECK_EQUAL(nRequired, 2); BOOST_CHECK(boost::get<PKHash>(&addresses[0]) && @@ -298,7 +298,7 @@ BOOST_AUTO_TEST_CASE(script_standard_ExtractDestinations) BOOST_CHECK(boost::get<PKHash>(&addresses[1]) && *boost::get<PKHash>(&addresses[1]) == PKHash(pubkeys[1])); - // TX_NULL_DATA + // TxoutType::NULL_DATA s.clear(); s << OP_RETURN << std::vector<unsigned char>({75}); BOOST_CHECK(!ExtractDestinations(s, whichType, addresses, nRequired)); diff --git a/src/test/script_tests.cpp b/src/test/script_tests.cpp index cb3ae290d1..0830743d61 100644 --- a/src/test/script_tests.cpp +++ b/src/test/script_tests.cpp @@ -282,7 +282,7 @@ public: CScript scriptPubKey = script; if (wm == WitnessMode::PKH) { uint160 hash; - CHash160().Write(&script[1], script.size() - 1).Finalize(hash.begin()); + CHash160().Write(MakeSpan(script).subspan(1)).Finalize(hash); script = CScript() << OP_DUP << OP_HASH160 << ToByteVector(hash) << OP_EQUALVERIFY << OP_CHECKSIG; scriptPubKey = CScript() << witnessversion << ToByteVector(hash); } else if (wm == WitnessMode::SH) { diff --git a/src/test/serialize_tests.cpp b/src/test/serialize_tests.cpp index c2328f931c..f625b67c2a 100644 --- a/src/test/serialize_tests.cpp +++ b/src/test/serialize_tests.cpp @@ -145,7 +145,7 @@ BOOST_AUTO_TEST_CASE(floats) for (int i = 0; i < 1000; i++) { ss << float(i); } - BOOST_CHECK(Hash(ss.begin(), ss.end()) == uint256S("8e8b4cf3e4df8b332057e3e23af42ebc663b61e0495d5e7e32d85099d7f3fe0c")); + BOOST_CHECK(Hash(ss) == uint256S("8e8b4cf3e4df8b332057e3e23af42ebc663b61e0495d5e7e32d85099d7f3fe0c")); // decode for (int i = 0; i < 1000; i++) { @@ -162,7 +162,7 @@ BOOST_AUTO_TEST_CASE(doubles) for (int i = 0; i < 1000; i++) { ss << double(i); } - BOOST_CHECK(Hash(ss.begin(), ss.end()) == uint256S("43d0c82591953c4eafe114590d392676a01585d25b25d433557f0d7878b23f96")); + BOOST_CHECK(Hash(ss) == uint256S("43d0c82591953c4eafe114590d392676a01585d25b25d433557f0d7878b23f96")); // decode for (int i = 0; i < 1000; i++) { diff --git a/src/test/settings_tests.cpp b/src/test/settings_tests.cpp index fcd831ccda..548fd020a6 100644 --- a/src/test/settings_tests.cpp +++ b/src/test/settings_tests.cpp @@ -12,10 +12,90 @@ #include <univalue.h> #include <util/strencodings.h> #include <util/string.h> +#include <util/system.h> #include <vector> +inline bool operator==(const util::SettingsValue& a, const util::SettingsValue& b) +{ + return a.write() == b.write(); +} + +inline std::ostream& operator<<(std::ostream& os, const util::SettingsValue& value) +{ + os << value.write(); + return os; +} + +inline std::ostream& operator<<(std::ostream& os, const std::pair<std::string, util::SettingsValue>& kv) +{ + util::SettingsValue out(util::SettingsValue::VOBJ); + out.__pushKV(kv.first, kv.second); + os << out.write(); + return os; +} + +inline void WriteText(const fs::path& path, const std::string& text) +{ + fsbridge::ofstream file; + file.open(path); + file << text; +} + BOOST_FIXTURE_TEST_SUITE(settings_tests, BasicTestingSetup) +BOOST_AUTO_TEST_CASE(ReadWrite) +{ + fs::path path = GetDataDir() / "settings.json"; + + WriteText(path, R"({ + "string": "string", + "num": 5, + "bool": true, + "null": null + })"); + + std::map<std::string, util::SettingsValue> expected{ + {"string", "string"}, + {"num", 5}, + {"bool", true}, + {"null", {}}, + }; + + // Check file read. + std::map<std::string, util::SettingsValue> values; + std::vector<std::string> errors; + BOOST_CHECK(util::ReadSettings(path, values, errors)); + BOOST_CHECK_EQUAL_COLLECTIONS(values.begin(), values.end(), expected.begin(), expected.end()); + BOOST_CHECK(errors.empty()); + + // Check no errors if file doesn't exist. + fs::remove(path); + BOOST_CHECK(util::ReadSettings(path, values, errors)); + BOOST_CHECK(values.empty()); + BOOST_CHECK(errors.empty()); + + // Check duplicate keys not allowed + WriteText(path, R"({ + "dupe": "string", + "dupe": "dupe" + })"); + BOOST_CHECK(!util::ReadSettings(path, values, errors)); + std::vector<std::string> dup_keys = {strprintf("Found duplicate key dupe in settings file %s", path.string())}; + BOOST_CHECK_EQUAL_COLLECTIONS(errors.begin(), errors.end(), dup_keys.begin(), dup_keys.end()); + + // Check non-kv json files not allowed + WriteText(path, R"("non-kv")"); + BOOST_CHECK(!util::ReadSettings(path, values, errors)); + std::vector<std::string> non_kv = {strprintf("Found non-object value \"non-kv\" in settings file %s", path.string())}; + BOOST_CHECK_EQUAL_COLLECTIONS(errors.begin(), errors.end(), non_kv.begin(), non_kv.end()); + + // Check invalid json not allowed + WriteText(path, R"(invalid json)"); + BOOST_CHECK(!util::ReadSettings(path, values, errors)); + std::vector<std::string> fail_parse = {strprintf("Unable to parse settings file %s", path.string())}; + BOOST_CHECK_EQUAL_COLLECTIONS(errors.begin(), errors.end(), fail_parse.begin(), fail_parse.end()); +} + //! Check settings struct contents against expected json strings. static void CheckValues(const util::Settings& settings, const std::string& single_val, const std::string& list_val) { @@ -148,7 +228,7 @@ BOOST_FIXTURE_TEST_CASE(Merge, MergeTestingSetup) if (OnlyHasDefaultSectionSetting(settings, network, name)) desc += " ignored"; desc += "\n"; - out_sha.Write((const unsigned char*)desc.data(), desc.size()); + out_sha.Write(MakeUCharSpan(desc)); if (out_file) { BOOST_REQUIRE(fwrite(desc.data(), 1, desc.size(), out_file) == desc.size()); } @@ -161,7 +241,7 @@ BOOST_FIXTURE_TEST_CASE(Merge, MergeTestingSetup) unsigned char out_sha_bytes[CSHA256::OUTPUT_SIZE]; out_sha.Finalize(out_sha_bytes); - std::string out_sha_hex = HexStr(std::begin(out_sha_bytes), std::end(out_sha_bytes)); + std::string out_sha_hex = HexStr(out_sha_bytes); // If check below fails, should manually dump the results with: // diff --git a/src/test/sighash_tests.cpp b/src/test/sighash_tests.cpp index 5ca136ea6e..c0bb92258b 100644 --- a/src/test/sighash_tests.cpp +++ b/src/test/sighash_tests.cpp @@ -141,7 +141,7 @@ BOOST_AUTO_TEST_CASE(sighash_test) ss << txTo; std::cout << "\t[\"" ; - std::cout << HexStr(ss.begin(), ss.end()) << "\", \""; + std::cout << HexStr(ss) << "\", \""; std::cout << HexStr(scriptCode) << "\", "; std::cout << nIn << ", "; std::cout << nHashType << ", \""; diff --git a/src/test/sync_tests.cpp b/src/test/sync_tests.cpp index 5c6c2ee38e..19029ebd3c 100644 --- a/src/test/sync_tests.cpp +++ b/src/test/sync_tests.cpp @@ -14,13 +14,15 @@ void TestPotentialDeadLockDetected(MutexType& mutex1, MutexType& mutex2) { LOCK2(mutex1, mutex2); } + BOOST_CHECK(LockStackEmpty()); bool error_thrown = false; try { LOCK2(mutex2, mutex1); } catch (const std::logic_error& e) { - BOOST_CHECK_EQUAL(e.what(), "potential deadlock detected"); + BOOST_CHECK_EQUAL(e.what(), "potential deadlock detected: mutex1 -> mutex2 -> mutex1"); error_thrown = true; } + BOOST_CHECK(LockStackEmpty()); #ifdef DEBUG_LOCKORDER BOOST_CHECK(error_thrown); #else @@ -40,9 +42,13 @@ BOOST_AUTO_TEST_CASE(potential_deadlock_detected) RecursiveMutex rmutex1, rmutex2; TestPotentialDeadLockDetected(rmutex1, rmutex2); + // The second test ensures that lock tracking data have not been broken by exception. + TestPotentialDeadLockDetected(rmutex1, rmutex2); Mutex mutex1, mutex2; TestPotentialDeadLockDetected(mutex1, mutex2); + // The second test ensures that lock tracking data have not been broken by exception. + TestPotentialDeadLockDetected(mutex1, mutex2); #ifdef DEBUG_LOCKORDER g_debug_lockorder_abort = prev; diff --git a/src/test/system_tests.cpp b/src/test/system_tests.cpp new file mode 100644 index 0000000000..a55145c738 --- /dev/null +++ b/src/test/system_tests.cpp @@ -0,0 +1,95 @@ +// Copyright (c) 2019 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. +// +#include <test/util/setup_common.h> +#include <util/system.h> +#include <univalue.h> + +#ifdef HAVE_BOOST_PROCESS +#include <boost/process.hpp> +#endif // HAVE_BOOST_PROCESS + +#include <boost/test/unit_test.hpp> + +BOOST_FIXTURE_TEST_SUITE(system_tests, BasicTestingSetup) + +// At least one test is required (in case HAVE_BOOST_PROCESS is not defined). +// Workaround for https://github.com/bitcoin/bitcoin/issues/19128 +BOOST_AUTO_TEST_CASE(dummy) +{ + BOOST_CHECK(true); +} + +#ifdef HAVE_BOOST_PROCESS + +bool checkMessage(const std::runtime_error& ex) +{ + // On Linux & Mac: "No such file or directory" + // On Windows: "The system cannot find the file specified." + const std::string what(ex.what()); + BOOST_CHECK(what.find("file") != std::string::npos); + return true; +} + +bool checkMessageFalse(const std::runtime_error& ex) +{ + BOOST_CHECK_EQUAL(ex.what(), std::string("RunCommandParseJSON error: process(false) returned 1: \n")); + return true; +} + +bool checkMessageStdErr(const std::runtime_error& ex) +{ + const std::string what(ex.what()); + BOOST_CHECK(what.find("RunCommandParseJSON error:") != std::string::npos); + return checkMessage(ex); +} + +BOOST_AUTO_TEST_CASE(run_command) +{ + { + const UniValue result = RunCommandParseJSON(""); + BOOST_CHECK(result.isNull()); + } + { +#ifdef WIN32 + // Windows requires single quotes to prevent escaping double quotes from the JSON... + const UniValue result = RunCommandParseJSON("echo '{\"success\": true}'"); +#else + // ... but Linux and macOS echo a single quote if it's used + const UniValue result = RunCommandParseJSON("echo \"{\"success\": true}\""); +#endif + BOOST_CHECK(result.isObject()); + const UniValue& success = find_value(result, "success"); + BOOST_CHECK(!success.isNull()); + BOOST_CHECK_EQUAL(success.getBool(), true); + } + { + // An invalid command is handled by Boost + BOOST_CHECK_EXCEPTION(RunCommandParseJSON("invalid_command"), boost::process::process_error, checkMessage); // Command failed + } + { + // Return non-zero exit code, no output to stderr + BOOST_CHECK_EXCEPTION(RunCommandParseJSON("false"), std::runtime_error, checkMessageFalse); + } + { + // Return non-zero exit code, with error message for stderr + BOOST_CHECK_EXCEPTION(RunCommandParseJSON("ls nosuchfile"), std::runtime_error, checkMessageStdErr); + } + { + BOOST_REQUIRE_THROW(RunCommandParseJSON("echo \"{\""), std::runtime_error); // Unable to parse JSON + } + // Test std::in, except for Windows +#ifndef WIN32 + { + const UniValue result = RunCommandParseJSON("cat", "{\"success\": true}"); + BOOST_CHECK(result.isObject()); + const UniValue& success = find_value(result, "success"); + BOOST_CHECK(!success.isNull()); + BOOST_CHECK_EQUAL(success.getBool(), true); + } +#endif +} +#endif // HAVE_BOOST_PROCESS + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/test/timedata_tests.cpp b/src/test/timedata_tests.cpp index 8c880babd1..1dcee23bbb 100644 --- a/src/test/timedata_tests.cpp +++ b/src/test/timedata_tests.cpp @@ -9,6 +9,7 @@ #include <test/util/setup_common.h> #include <timedata.h> #include <util/string.h> +#include <util/translation.h> #include <warnings.h> #include <string> @@ -66,7 +67,7 @@ BOOST_AUTO_TEST_CASE(addtimedata) MultiAddTimeData(1, DEFAULT_MAX_TIME_ADJUSTMENT + 1); //filter size 5 } - BOOST_CHECK(GetWarnings(true).find("clock is wrong") != std::string::npos); + BOOST_CHECK(GetWarnings(true).original.find("clock is wrong") != std::string::npos); // nTimeOffset is not changed if the median of offsets exceeds DEFAULT_MAX_TIME_ADJUSTMENT BOOST_CHECK_EQUAL(GetTimeOffset(), 0); diff --git a/src/test/transaction_tests.cpp b/src/test/transaction_tests.cpp index ddbc68f8e2..4bf6e734ce 100644 --- a/src/test/transaction_tests.cpp +++ b/src/test/transaction_tests.cpp @@ -716,12 +716,12 @@ BOOST_AUTO_TEST_CASE(test_IsStandard) BOOST_CHECK(!IsStandardTx(CTransaction(t), reason)); BOOST_CHECK_EQUAL(reason, "scriptpubkey"); - // MAX_OP_RETURN_RELAY-byte TX_NULL_DATA (standard) + // MAX_OP_RETURN_RELAY-byte TxoutType::NULL_DATA (standard) t.vout[0].scriptPubKey = CScript() << OP_RETURN << ParseHex("04678afdb0fe5548271967f1a67130b7105cd6a828e03909a67962e0ea1f61deb649f6bc3f4cef3804678afdb0fe5548271967f1a67130b7105cd6a828e03909a67962e0ea1f61deb649f6bc3f4cef38"); BOOST_CHECK_EQUAL(MAX_OP_RETURN_RELAY, t.vout[0].scriptPubKey.size()); BOOST_CHECK(IsStandardTx(CTransaction(t), reason)); - // MAX_OP_RETURN_RELAY+1-byte TX_NULL_DATA (non-standard) + // MAX_OP_RETURN_RELAY+1-byte TxoutType::NULL_DATA (non-standard) t.vout[0].scriptPubKey = CScript() << OP_RETURN << ParseHex("04678afdb0fe5548271967f1a67130b7105cd6a828e03909a67962e0ea1f61deb649f6bc3f4cef3804678afdb0fe5548271967f1a67130b7105cd6a828e03909a67962e0ea1f61deb649f6bc3f4cef3800"); BOOST_CHECK_EQUAL(MAX_OP_RETURN_RELAY + 1, t.vout[0].scriptPubKey.size()); reason.clear(); @@ -745,12 +745,12 @@ BOOST_AUTO_TEST_CASE(test_IsStandard) BOOST_CHECK(!IsStandardTx(CTransaction(t), reason)); BOOST_CHECK_EQUAL(reason, "scriptpubkey"); - // TX_NULL_DATA w/o PUSHDATA + // TxoutType::NULL_DATA w/o PUSHDATA t.vout.resize(1); t.vout[0].scriptPubKey = CScript() << OP_RETURN; BOOST_CHECK(IsStandardTx(CTransaction(t), reason)); - // Only one TX_NULL_DATA permitted in all cases + // Only one TxoutType::NULL_DATA permitted in all cases t.vout.resize(2); t.vout[0].scriptPubKey = CScript() << OP_RETURN << ParseHex("04678afdb0fe5548271967f1a67130b7105cd6a828e03909a67962e0ea1f61deb649f6bc3f4cef38"); t.vout[1].scriptPubKey = CScript() << OP_RETURN << ParseHex("04678afdb0fe5548271967f1a67130b7105cd6a828e03909a67962e0ea1f61deb649f6bc3f4cef38"); diff --git a/src/test/util/mining.cpp b/src/test/util/mining.cpp index dac7f1a07b..74536ae74c 100644 --- a/src/test/util/mining.cpp +++ b/src/test/util/mining.cpp @@ -11,6 +11,7 @@ #include <node/context.h> #include <pow.h> #include <script/standard.h> +#include <util/check.h> #include <validation.h> CTxIn generatetoaddress(const NodeContext& node, const std::string& address) @@ -31,7 +32,7 @@ CTxIn MineBlock(const NodeContext& node, const CScript& coinbase_scriptPubKey) assert(block->nNonce); } - bool processed{EnsureChainman(node).ProcessNewBlock(Params(), block, true, nullptr)}; + bool processed{Assert(node.chainman)->ProcessNewBlock(Params(), block, true, nullptr)}; assert(processed); return CTxIn{block->vtx[0]->GetHash(), 0}; @@ -39,9 +40,8 @@ CTxIn MineBlock(const NodeContext& node, const CScript& coinbase_scriptPubKey) std::shared_ptr<CBlock> PrepareBlock(const NodeContext& node, const CScript& coinbase_scriptPubKey) { - assert(node.mempool); auto block = std::make_shared<CBlock>( - BlockAssembler{*node.mempool, Params()} + BlockAssembler{*Assert(node.mempool), Params()} .CreateNewBlock(coinbase_scriptPubKey) ->block); diff --git a/src/test/util/setup_common.cpp b/src/test/util/setup_common.cpp index 3b7a7c8d12..b2ae1cb845 100644 --- a/src/test/util/setup_common.cpp +++ b/src/test/util/setup_common.cpp @@ -11,6 +11,7 @@ #include <consensus/validation.h> #include <crypto/sha256.h> #include <init.h> +#include <interfaces/chain.h> #include <miner.h> #include <net.h> #include <net_processing.h> @@ -19,6 +20,7 @@ #include <rpc/blockchain.h> #include <rpc/register.h> #include <rpc/server.h> +#include <scheduler.h> #include <script/sigcache.h> #include <streams.h> #include <txdb.h> @@ -31,6 +33,7 @@ #include <util/vector.h> #include <validation.h> #include <validationinterface.h> +#include <walletinitinterface.h> #include <functional> @@ -74,11 +77,13 @@ BasicTestingSetup::BasicTestingSetup(const std::string& chainName, const std::ve "dummy", "-printtoconsole=0", "-logtimemicros", + "-logthreadnames", "-debug", "-debugexclude=libevent", "-debugexclude=leveldb", }, extra_args); + util::ThreadRename("test"); fs::create_directories(m_path_root); gArgs.ForceSetArg("-datadir", m_path_root.string()); ClearDatadirCache(); @@ -101,6 +106,8 @@ BasicTestingSetup::BasicTestingSetup(const std::string& chainName, const std::ve SetupNetworking(); InitSignatureCache(); InitScriptExecutionCache(); + m_node.chain = interfaces::MakeChain(m_node); + g_wallet_init_interface.Construct(m_node); fCheckBlockIndex = true; static bool noui_connected = false; if (!noui_connected) { @@ -129,7 +136,7 @@ TestingSetup::TestingSetup(const std::string& chainName, const std::vector<const // We have to run a scheduler thread to prevent ActivateBestChain // from blocking due to queue overrun. - threadGroup.create_thread([&]{ m_node.scheduler->serviceQueue(); }); + threadGroup.create_thread([&] { TraceThread("scheduler", [&] { m_node.scheduler->serviceQueue(); }); }); GetMainSignals().RegisterBackgroundSignalScheduler(*m_node.scheduler); pblocktree.reset(new CBlockTreeDB(1 << 20, true)); @@ -139,7 +146,7 @@ TestingSetup::TestingSetup(const std::string& chainName, const std::vector<const ::ChainstateActive().InitCoinsDB( /* cache_size_bytes */ 1 << 23, /* in_memory */ true, /* should_wipe */ false); assert(!::ChainstateActive().CanFlushToDisk()); - ::ChainstateActive().InitCoinsCache(); + ::ChainstateActive().InitCoinsCache(1 << 23); assert(::ChainstateActive().CanFlushToDisk()); if (!LoadGenesisBlock(chainparams)) { throw std::runtime_error("LoadGenesisBlock failed."); @@ -179,9 +186,9 @@ TestingSetup::~TestingSetup() m_node.connman.reset(); m_node.banman.reset(); m_node.args = nullptr; + UnloadBlockIndex(m_node.mempool); m_node.mempool = nullptr; m_node.scheduler.reset(); - UnloadBlockIndex(); m_node.chainman->Reset(); m_node.chainman = nullptr; pblocktree.reset(); @@ -228,7 +235,7 @@ CBlock TestChain100Setup::CreateAndProcessBlock(const std::vector<CMutableTransa while (!CheckProofOfWork(block.GetHash(), block.nBits, chainparams.GetConsensus())) ++block.nNonce; std::shared_ptr<const CBlock> shared_pblock = std::make_shared<const CBlock>(block); - EnsureChainman(m_node).ProcessNewBlock(chainparams, shared_pblock, true, nullptr); + Assert(m_node.chainman)->ProcessNewBlock(chainparams, shared_pblock, true, nullptr); CBlock result = block; return result; diff --git a/src/test/util/setup_common.h b/src/test/util/setup_common.h index 2477f9ad06..78b279e42a 100644 --- a/src/test/util/setup_common.h +++ b/src/test/util/setup_common.h @@ -11,13 +11,13 @@ #include <node/context.h> #include <pubkey.h> #include <random.h> -#include <scheduler.h> #include <txmempool.h> +#include <util/check.h> #include <util/string.h> #include <type_traits> -#include <boost/thread.hpp> +#include <boost/thread/thread.hpp> /** This is connected to the logger. Can be used to redirect logs to any other log */ extern const std::function<void(const std::string&)> G_TEST_LOG_FUN; diff --git a/src/test/util/transaction_utils.h b/src/test/util/transaction_utils.h index 1beddd334b..6f2faeec6c 100644 --- a/src/test/util/transaction_utils.h +++ b/src/test/util/transaction_utils.h @@ -22,8 +22,8 @@ CMutableTransaction BuildCreditingTransaction(const CScript& scriptPubKey, int n CMutableTransaction BuildSpendingTransaction(const CScript& scriptSig, const CScriptWitness& scriptWitness, const CTransaction& txCredit); // Helper: create two dummy transactions, each with two outputs. -// The first has nValues[0] and nValues[1] outputs paid to a TX_PUBKEY, -// the second nValues[2] and nValues[3] outputs paid to a TX_PUBKEYHASH. +// The first has nValues[0] and nValues[1] outputs paid to a TxoutType::PUBKEY, +// the second nValues[2] and nValues[3] outputs paid to a TxoutType::PUBKEYHASH. std::vector<CMutableTransaction> SetupDummyInputs(FillableSigningProvider& keystoreRet, CCoinsViewCache& coinsRet, const std::array<CAmount,4>& nValues); #endif // BITCOIN_TEST_UTIL_TRANSACTION_UTILS_H diff --git a/src/test/util_tests.cpp b/src/test/util_tests.cpp index cf26ca3adb..b49370c967 100644 --- a/src/test/util_tests.cpp +++ b/src/test/util_tests.cpp @@ -9,6 +9,7 @@ #include <key.h> // For CKey #include <optional.h> #include <sync.h> +#include <test/util/logging.h> #include <test/util/setup_common.h> #include <test/util/str.h> #include <uint256.h> @@ -41,6 +42,16 @@ namespace BCLog { BOOST_FIXTURE_TEST_SUITE(util_tests, BasicTestingSetup) +BOOST_AUTO_TEST_CASE(util_check) +{ + // Check that Assert can forward + const std::unique_ptr<int> p_two = Assert(MakeUnique<int>(2)); + // Check that Assert works on lvalues and rvalues + const int two = *Assert(p_two); + Assert(two == 2); + Assert(true); +} + BOOST_AUTO_TEST_CASE(util_criticalsection) { RecursiveMutex cs; @@ -94,47 +105,24 @@ BOOST_AUTO_TEST_CASE(util_ParseHex) BOOST_AUTO_TEST_CASE(util_HexStr) { BOOST_CHECK_EQUAL( - HexStr(ParseHex_expected, ParseHex_expected + sizeof(ParseHex_expected)), + HexStr(ParseHex_expected), "04678afdb0fe5548271967f1a67130b7105cd6a828e03909a67962e0ea1f61deb649f6bc3f4cef38c4f35504e51ec112de5c384df7ba0b8d578a4c702b6bf11d5f"); BOOST_CHECK_EQUAL( - HexStr(ParseHex_expected + sizeof(ParseHex_expected), - ParseHex_expected + sizeof(ParseHex_expected)), + HexStr(Span<const unsigned char>( + ParseHex_expected + sizeof(ParseHex_expected), + ParseHex_expected + sizeof(ParseHex_expected))), ""); BOOST_CHECK_EQUAL( - HexStr(ParseHex_expected, ParseHex_expected), + HexStr(Span<const unsigned char>(ParseHex_expected, ParseHex_expected)), ""); std::vector<unsigned char> ParseHex_vec(ParseHex_expected, ParseHex_expected + 5); BOOST_CHECK_EQUAL( - HexStr(ParseHex_vec.rbegin(), ParseHex_vec.rend()), - "b0fd8a6704" - ); - - BOOST_CHECK_EQUAL( - HexStr(std::reverse_iterator<const uint8_t *>(ParseHex_expected), - std::reverse_iterator<const uint8_t *>(ParseHex_expected)), - "" - ); - - BOOST_CHECK_EQUAL( - HexStr(std::reverse_iterator<const uint8_t *>(ParseHex_expected + 1), - std::reverse_iterator<const uint8_t *>(ParseHex_expected)), - "04" - ); - - BOOST_CHECK_EQUAL( - HexStr(std::reverse_iterator<const uint8_t *>(ParseHex_expected + 5), - std::reverse_iterator<const uint8_t *>(ParseHex_expected)), - "b0fd8a6704" - ); - - BOOST_CHECK_EQUAL( - HexStr(std::reverse_iterator<const uint8_t *>(ParseHex_expected + 65), - std::reverse_iterator<const uint8_t *>(ParseHex_expected)), - "5f1df16b2b704c8a578d0bbaf74d385cde12c11ee50455f3c438ef4c3fbcf649b6de611feae06279a60939e028a8d65c10b73071a6f16719274855feb0fd8a6704" + HexStr(ParseHex_vec), + "04678afdb0" ); } @@ -998,7 +986,7 @@ BOOST_FIXTURE_TEST_CASE(util_ArgsMerge, ArgsMergeTestingSetup) desc += "\n"; - out_sha.Write((const unsigned char*)desc.data(), desc.size()); + out_sha.Write(MakeUCharSpan(desc)); if (out_file) { BOOST_REQUIRE(fwrite(desc.data(), 1, desc.size(), out_file) == desc.size()); } @@ -1011,7 +999,7 @@ BOOST_FIXTURE_TEST_CASE(util_ArgsMerge, ArgsMergeTestingSetup) unsigned char out_sha_bytes[CSHA256::OUTPUT_SIZE]; out_sha.Finalize(out_sha_bytes); - std::string out_sha_hex = HexStr(std::begin(out_sha_bytes), std::end(out_sha_bytes)); + std::string out_sha_hex = HexStr(out_sha_bytes); // If check below fails, should manually dump the results with: // @@ -1101,7 +1089,7 @@ BOOST_FIXTURE_TEST_CASE(util_ChainMerge, ChainMergeTestingSetup) } desc += "\n"; - out_sha.Write((const unsigned char*)desc.data(), desc.size()); + out_sha.Write(MakeUCharSpan(desc)); if (out_file) { BOOST_REQUIRE(fwrite(desc.data(), 1, desc.size(), out_file) == desc.size()); } @@ -1114,7 +1102,7 @@ BOOST_FIXTURE_TEST_CASE(util_ChainMerge, ChainMergeTestingSetup) unsigned char out_sha_bytes[CSHA256::OUTPUT_SIZE]; out_sha.Finalize(out_sha_bytes); - std::string out_sha_hex = HexStr(std::begin(out_sha_bytes), std::end(out_sha_bytes)); + std::string out_sha_hex = HexStr(out_sha_bytes); // If check below fails, should manually dump the results with: // @@ -1128,6 +1116,28 @@ BOOST_FIXTURE_TEST_CASE(util_ChainMerge, ChainMergeTestingSetup) BOOST_CHECK_EQUAL(out_sha_hex, "f0b3a3c29869edc765d579c928f7f1690a71fbb673b49ccf39cbc4de18156a0d"); } +BOOST_AUTO_TEST_CASE(util_ReadWriteSettings) +{ + // Test writing setting. + TestArgsManager args1; + args1.LockSettings([&](util::Settings& settings) { settings.rw_settings["name"] = "value"; }); + args1.WriteSettingsFile(); + + // Test reading setting. + TestArgsManager args2; + args2.ReadSettingsFile(); + args2.LockSettings([&](util::Settings& settings) { BOOST_CHECK_EQUAL(settings.rw_settings["name"].get_str(), "value"); }); + + // Test error logging, and remove previously written setting. + { + ASSERT_DEBUG_LOG("Failed renaming settings file"); + fs::remove(GetDataDir() / "settings.json"); + fs::create_directory(GetDataDir() / "settings.json"); + args2.WriteSettingsFile(); + fs::remove(GetDataDir() / "settings.json"); + } +} + BOOST_AUTO_TEST_CASE(util_FormatMoney) { BOOST_CHECK_EQUAL(FormatMoney(0), "0.00"); @@ -1829,7 +1839,7 @@ BOOST_AUTO_TEST_CASE(test_spanparsing) // Const(...): parse a constant, update span to skip it if successful input = "MilkToastHoney"; - sp = MakeSpan(input); + sp = input; success = Const("", sp); // empty BOOST_CHECK(success); BOOST_CHECK_EQUAL(SpanToStr(sp), "MilkToastHoney"); @@ -1854,7 +1864,7 @@ BOOST_AUTO_TEST_CASE(test_spanparsing) // Func(...): parse a function call, update span to argument if successful input = "Foo(Bar(xy,z()))"; - sp = MakeSpan(input); + sp = input; success = Func("FooBar", sp); BOOST_CHECK(!success); @@ -1877,31 +1887,31 @@ BOOST_AUTO_TEST_CASE(test_spanparsing) Span<const char> result; input = "(n*(n-1))/2"; - sp = MakeSpan(input); + sp = input; result = Expr(sp); BOOST_CHECK_EQUAL(SpanToStr(result), "(n*(n-1))/2"); BOOST_CHECK_EQUAL(SpanToStr(sp), ""); input = "foo,bar"; - sp = MakeSpan(input); + sp = input; result = Expr(sp); BOOST_CHECK_EQUAL(SpanToStr(result), "foo"); BOOST_CHECK_EQUAL(SpanToStr(sp), ",bar"); input = "(aaaaa,bbbbb()),c"; - sp = MakeSpan(input); + sp = input; result = Expr(sp); BOOST_CHECK_EQUAL(SpanToStr(result), "(aaaaa,bbbbb())"); BOOST_CHECK_EQUAL(SpanToStr(sp), ",c"); input = "xyz)foo"; - sp = MakeSpan(input); + sp = input; result = Expr(sp); BOOST_CHECK_EQUAL(SpanToStr(result), "xyz"); BOOST_CHECK_EQUAL(SpanToStr(sp), ")foo"); input = "((a),(b),(c)),xxx"; - sp = MakeSpan(input); + sp = input; result = Expr(sp); BOOST_CHECK_EQUAL(SpanToStr(result), "((a),(b),(c))"); BOOST_CHECK_EQUAL(SpanToStr(sp), ",xxx"); @@ -1910,7 +1920,7 @@ BOOST_AUTO_TEST_CASE(test_spanparsing) std::vector<Span<const char>> results; input = "xxx"; - results = Split(MakeSpan(input), 'x'); + results = Split(input, 'x'); BOOST_CHECK_EQUAL(results.size(), 4U); BOOST_CHECK_EQUAL(SpanToStr(results[0]), ""); BOOST_CHECK_EQUAL(SpanToStr(results[1]), ""); @@ -1918,19 +1928,19 @@ BOOST_AUTO_TEST_CASE(test_spanparsing) BOOST_CHECK_EQUAL(SpanToStr(results[3]), ""); input = "one#two#three"; - results = Split(MakeSpan(input), '-'); + results = Split(input, '-'); BOOST_CHECK_EQUAL(results.size(), 1U); BOOST_CHECK_EQUAL(SpanToStr(results[0]), "one#two#three"); input = "one#two#three"; - results = Split(MakeSpan(input), '#'); + results = Split(input, '#'); BOOST_CHECK_EQUAL(results.size(), 3U); BOOST_CHECK_EQUAL(SpanToStr(results[0]), "one"); BOOST_CHECK_EQUAL(SpanToStr(results[1]), "two"); BOOST_CHECK_EQUAL(SpanToStr(results[2]), "three"); input = "*foo*bar*"; - results = Split(MakeSpan(input), '*'); + results = Split(input, '*'); BOOST_CHECK_EQUAL(results.size(), 4U); BOOST_CHECK_EQUAL(SpanToStr(results[0]), ""); BOOST_CHECK_EQUAL(SpanToStr(results[1]), "foo"); @@ -2153,8 +2163,8 @@ BOOST_AUTO_TEST_CASE(message_hash) std::string(1, (char)unsigned_tx.length()) + unsigned_tx; - const uint256 signature_hash = Hash(unsigned_tx.begin(), unsigned_tx.end()); - const uint256 message_hash1 = Hash(prefixed_message.begin(), prefixed_message.end()); + const uint256 signature_hash = Hash(unsigned_tx); + const uint256 message_hash1 = Hash(prefixed_message); const uint256 message_hash2 = MessageHash(unsigned_tx); BOOST_CHECK_EQUAL(message_hash1, message_hash2); diff --git a/src/test/util_threadnames_tests.cpp b/src/test/util_threadnames_tests.cpp index 4dcc080b2d..f3f9fb2bff 100644 --- a/src/test/util_threadnames_tests.cpp +++ b/src/test/util_threadnames_tests.cpp @@ -53,8 +53,6 @@ std::set<std::string> RenameEnMasse(int num_threads) */ BOOST_AUTO_TEST_CASE(util_threadnames_test_rename_threaded) { - BOOST_CHECK_EQUAL(util::ThreadGetInternalName(), ""); - #if !defined(HAVE_THREAD_LOCAL) // This test doesn't apply to platforms where we don't have thread_local. return; diff --git a/src/test/validation_block_tests.cpp b/src/test/validation_block_tests.cpp index 45e0c5484e..8e85b7df3e 100644 --- a/src/test/validation_block_tests.cpp +++ b/src/test/validation_block_tests.cpp @@ -163,10 +163,10 @@ BOOST_AUTO_TEST_CASE(processnewblock_signals_ordering) std::transform(blocks.begin(), blocks.end(), std::back_inserter(headers), [](std::shared_ptr<const CBlock> b) { return b->GetBlockHeader(); }); // Process all the headers so we understand the toplogy of the chain - BOOST_CHECK(EnsureChainman(m_node).ProcessNewBlockHeaders(headers, state, Params())); + BOOST_CHECK(Assert(m_node.chainman)->ProcessNewBlockHeaders(headers, state, Params())); // Connect the genesis block and drain any outstanding events - BOOST_CHECK(EnsureChainman(m_node).ProcessNewBlock(Params(), std::make_shared<CBlock>(Params().GenesisBlock()), true, &ignored)); + BOOST_CHECK(Assert(m_node.chainman)->ProcessNewBlock(Params(), std::make_shared<CBlock>(Params().GenesisBlock()), true, &ignored)); SyncWithValidationInterfaceQueue(); // subscribe to events (this subscriber will validate event ordering) @@ -188,13 +188,13 @@ BOOST_AUTO_TEST_CASE(processnewblock_signals_ordering) FastRandomContext insecure; for (int i = 0; i < 1000; i++) { auto block = blocks[insecure.randrange(blocks.size() - 1)]; - EnsureChainman(m_node).ProcessNewBlock(Params(), block, true, &ignored); + Assert(m_node.chainman)->ProcessNewBlock(Params(), block, true, &ignored); } // to make sure that eventually we process the full chain - do it here for (auto block : blocks) { if (block->vtx.size() == 1) { - bool processed = EnsureChainman(m_node).ProcessNewBlock(Params(), block, true, &ignored); + bool processed = Assert(m_node.chainman)->ProcessNewBlock(Params(), block, true, &ignored); assert(processed); } } @@ -233,7 +233,7 @@ BOOST_AUTO_TEST_CASE(mempool_locks_reorg) { bool ignored; auto ProcessBlock = [&](std::shared_ptr<const CBlock> block) -> bool { - return EnsureChainman(m_node).ProcessNewBlock(Params(), block, /* fForceProcessing */ true, /* fNewBlock */ &ignored); + return Assert(m_node.chainman)->ProcessNewBlock(Params(), block, /* fForceProcessing */ true, /* fNewBlock */ &ignored); }; // Process all mined blocks diff --git a/src/test/validation_chainstate_tests.cpp b/src/test/validation_chainstate_tests.cpp new file mode 100644 index 0000000000..2076a1096a --- /dev/null +++ b/src/test/validation_chainstate_tests.cpp @@ -0,0 +1,74 @@ +// Copyright (c) 2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. +// +#include <random.h> +#include <uint256.h> +#include <consensus/validation.h> +#include <sync.h> +#include <test/util/setup_common.h> +#include <validation.h> + +#include <vector> + +#include <boost/test/unit_test.hpp> + +BOOST_FIXTURE_TEST_SUITE(validation_chainstate_tests, TestingSetup) + +//! Test resizing coins-related CChainState caches during runtime. +//! +BOOST_AUTO_TEST_CASE(validation_chainstate_resize_caches) +{ + ChainstateManager manager; + + //! Create and add a Coin with DynamicMemoryUsage of 80 bytes to the given view. + auto add_coin = [](CCoinsViewCache& coins_view) -> COutPoint { + Coin newcoin; + uint256 txid = InsecureRand256(); + COutPoint outp{txid, 0}; + newcoin.nHeight = 1; + newcoin.out.nValue = InsecureRand32(); + newcoin.out.scriptPubKey.assign((uint32_t)56, 1); + coins_view.AddCoin(outp, std::move(newcoin), false); + + return outp; + }; + + CChainState& c1 = *WITH_LOCK(cs_main, return &manager.InitializeChainstate()); + c1.InitCoinsDB( + /* cache_size_bytes */ 1 << 23, /* in_memory */ true, /* should_wipe */ false); + WITH_LOCK(::cs_main, c1.InitCoinsCache(1 << 23)); + + // Add a coin to the in-memory cache, upsize once, then downsize. + { + LOCK(::cs_main); + auto outpoint = add_coin(c1.CoinsTip()); + + // Set a meaningless bestblock value in the coinsview cache - otherwise we won't + // flush during ResizecoinsCaches() and will subsequently hit an assertion. + c1.CoinsTip().SetBestBlock(InsecureRand256()); + + BOOST_CHECK(c1.CoinsTip().HaveCoinInCache(outpoint)); + + c1.ResizeCoinsCaches( + 1 << 24, // upsizing the coinsview cache + 1 << 22 // downsizing the coinsdb cache + ); + + // View should still have the coin cached, since we haven't destructed the cache on upsize. + BOOST_CHECK(c1.CoinsTip().HaveCoinInCache(outpoint)); + + c1.ResizeCoinsCaches( + 1 << 22, // downsizing the coinsview cache + 1 << 23 // upsizing the coinsdb cache + ); + + // The view cache should be empty since we had to destruct to downsize. + BOOST_CHECK(!c1.CoinsTip().HaveCoinInCache(outpoint)); + } + + // Avoid triggering the address sanitizer. + WITH_LOCK(::cs_main, manager.Unload()); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/test/validation_chainstatemanager_tests.cpp b/src/test/validation_chainstatemanager_tests.cpp index 0d149285ad..887a48124f 100644 --- a/src/test/validation_chainstatemanager_tests.cpp +++ b/src/test/validation_chainstatemanager_tests.cpp @@ -28,13 +28,11 @@ BOOST_AUTO_TEST_CASE(chainstatemanager) // Create a legacy (IBD) chainstate. // - ENTER_CRITICAL_SECTION(cs_main); - CChainState& c1 = manager.InitializeChainstate(); - LEAVE_CRITICAL_SECTION(cs_main); + CChainState& c1 = *WITH_LOCK(::cs_main, return &manager.InitializeChainstate()); chainstates.push_back(&c1); c1.InitCoinsDB( /* cache_size_bytes */ 1 << 23, /* in_memory */ true, /* should_wipe */ false); - WITH_LOCK(::cs_main, c1.InitCoinsCache()); + WITH_LOCK(::cs_main, c1.InitCoinsCache(1 << 23)); BOOST_CHECK(!manager.IsSnapshotActive()); BOOST_CHECK(!manager.IsSnapshotValidated()); @@ -56,13 +54,11 @@ BOOST_AUTO_TEST_CASE(chainstatemanager) // Create a snapshot-based chainstate. // - ENTER_CRITICAL_SECTION(cs_main); - CChainState& c2 = manager.InitializeChainstate(GetRandHash()); - LEAVE_CRITICAL_SECTION(cs_main); + CChainState& c2 = *WITH_LOCK(::cs_main, return &manager.InitializeChainstate(GetRandHash())); chainstates.push_back(&c2); c2.InitCoinsDB( /* cache_size_bytes */ 1 << 23, /* in_memory */ true, /* should_wipe */ false); - WITH_LOCK(::cs_main, c2.InitCoinsCache()); + WITH_LOCK(::cs_main, c2.InitCoinsCache(1 << 23)); // Unlike c1, which doesn't have any blocks. Gets us different tip, height. c2.LoadGenesisBlock(chainparams); BlockValidationState _; @@ -104,4 +100,54 @@ BOOST_AUTO_TEST_CASE(chainstatemanager) WITH_LOCK(::cs_main, manager.Unload()); } +//! Test rebalancing the caches associated with each chainstate. +BOOST_AUTO_TEST_CASE(chainstatemanager_rebalance_caches) +{ + ChainstateManager manager; + size_t max_cache = 10000; + manager.m_total_coinsdb_cache = max_cache; + manager.m_total_coinstip_cache = max_cache; + + std::vector<CChainState*> chainstates; + + // Create a legacy (IBD) chainstate. + // + CChainState& c1 = *WITH_LOCK(cs_main, return &manager.InitializeChainstate()); + chainstates.push_back(&c1); + c1.InitCoinsDB( + /* cache_size_bytes */ 1 << 23, /* in_memory */ true, /* should_wipe */ false); + + { + LOCK(::cs_main); + c1.InitCoinsCache(1 << 23); + c1.CoinsTip().SetBestBlock(InsecureRand256()); + manager.MaybeRebalanceCaches(); + } + + BOOST_CHECK_EQUAL(c1.m_coinstip_cache_size_bytes, max_cache); + BOOST_CHECK_EQUAL(c1.m_coinsdb_cache_size_bytes, max_cache); + + // Create a snapshot-based chainstate. + // + CChainState& c2 = *WITH_LOCK(cs_main, return &manager.InitializeChainstate(GetRandHash())); + chainstates.push_back(&c2); + c2.InitCoinsDB( + /* cache_size_bytes */ 1 << 23, /* in_memory */ true, /* should_wipe */ false); + + { + LOCK(::cs_main); + c2.InitCoinsCache(1 << 23); + c2.CoinsTip().SetBestBlock(InsecureRand256()); + manager.MaybeRebalanceCaches(); + } + + // Since both chainstates are considered to be in initial block download, + // the snapshot chainstate should take priority. + BOOST_CHECK_CLOSE(c1.m_coinstip_cache_size_bytes, max_cache * 0.05, 1); + BOOST_CHECK_CLOSE(c1.m_coinsdb_cache_size_bytes, max_cache * 0.05, 1); + BOOST_CHECK_CLOSE(c2.m_coinstip_cache_size_bytes, max_cache * 0.95, 1); + BOOST_CHECK_CLOSE(c2.m_coinsdb_cache_size_bytes, max_cache * 0.95, 1); + +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/src/test/validation_flush_tests.cpp b/src/test/validation_flush_tests.cpp index a863e3a4d5..8bac914f05 100644 --- a/src/test/validation_flush_tests.cpp +++ b/src/test/validation_flush_tests.cpp @@ -21,7 +21,7 @@ BOOST_AUTO_TEST_CASE(getcoinscachesizestate) BlockManager blockman{}; CChainState chainstate{blockman}; chainstate.InitCoinsDB(/*cache_size_bytes*/ 1 << 10, /*in_memory*/ true, /*should_wipe*/ false); - WITH_LOCK(::cs_main, chainstate.InitCoinsCache()); + WITH_LOCK(::cs_main, chainstate.InitCoinsCache(1 << 10)); CTxMemPool tx_pool{}; constexpr bool is_64_bit = sizeof(void*) == 8; @@ -56,7 +56,7 @@ BOOST_AUTO_TEST_CASE(getcoinscachesizestate) // Without any coins in the cache, we shouldn't need to flush. BOOST_CHECK_EQUAL( - chainstate.GetCoinsCacheSizeState(tx_pool, MAX_COINS_CACHE_BYTES, /*max_mempool_size_bytes*/ 0), + chainstate.GetCoinsCacheSizeState(&tx_pool, MAX_COINS_CACHE_BYTES, /*max_mempool_size_bytes*/ 0), CoinsCacheSizeState::OK); // If the initial memory allocations of cacheCoins don't match these common @@ -71,7 +71,7 @@ BOOST_AUTO_TEST_CASE(getcoinscachesizestate) } BOOST_CHECK_EQUAL( - chainstate.GetCoinsCacheSizeState(tx_pool, MAX_COINS_CACHE_BYTES, /*max_mempool_size_bytes*/ 0), + chainstate.GetCoinsCacheSizeState(&tx_pool, MAX_COINS_CACHE_BYTES, /*max_mempool_size_bytes*/ 0), CoinsCacheSizeState::CRITICAL); BOOST_TEST_MESSAGE("Exiting cache flush tests early due to unsupported arch"); @@ -92,7 +92,7 @@ BOOST_AUTO_TEST_CASE(getcoinscachesizestate) print_view_mem_usage(view); BOOST_CHECK_EQUAL(view.AccessCoin(res).DynamicMemoryUsage(), COIN_SIZE); BOOST_CHECK_EQUAL( - chainstate.GetCoinsCacheSizeState(tx_pool, MAX_COINS_CACHE_BYTES, /*max_mempool_size_bytes*/ 0), + chainstate.GetCoinsCacheSizeState(&tx_pool, MAX_COINS_CACHE_BYTES, /*max_mempool_size_bytes*/ 0), CoinsCacheSizeState::OK); } @@ -100,26 +100,26 @@ BOOST_AUTO_TEST_CASE(getcoinscachesizestate) for (int i{0}; i < 4; ++i) { add_coin(view); print_view_mem_usage(view); - if (chainstate.GetCoinsCacheSizeState(tx_pool, MAX_COINS_CACHE_BYTES, /*max_mempool_size_bytes*/ 0) == + if (chainstate.GetCoinsCacheSizeState(&tx_pool, MAX_COINS_CACHE_BYTES, /*max_mempool_size_bytes*/ 0) == CoinsCacheSizeState::CRITICAL) { break; } } BOOST_CHECK_EQUAL( - chainstate.GetCoinsCacheSizeState(tx_pool, MAX_COINS_CACHE_BYTES, /*max_mempool_size_bytes*/ 0), + chainstate.GetCoinsCacheSizeState(&tx_pool, MAX_COINS_CACHE_BYTES, /*max_mempool_size_bytes*/ 0), CoinsCacheSizeState::CRITICAL); // Passing non-zero max mempool usage should allow us more headroom. BOOST_CHECK_EQUAL( - chainstate.GetCoinsCacheSizeState(tx_pool, MAX_COINS_CACHE_BYTES, /*max_mempool_size_bytes*/ 1 << 10), + chainstate.GetCoinsCacheSizeState(&tx_pool, MAX_COINS_CACHE_BYTES, /*max_mempool_size_bytes*/ 1 << 10), CoinsCacheSizeState::OK); for (int i{0}; i < 3; ++i) { add_coin(view); print_view_mem_usage(view); BOOST_CHECK_EQUAL( - chainstate.GetCoinsCacheSizeState(tx_pool, MAX_COINS_CACHE_BYTES, /*max_mempool_size_bytes*/ 1 << 10), + chainstate.GetCoinsCacheSizeState(&tx_pool, MAX_COINS_CACHE_BYTES, /*max_mempool_size_bytes*/ 1 << 10), CoinsCacheSizeState::OK); } @@ -135,7 +135,7 @@ BOOST_AUTO_TEST_CASE(getcoinscachesizestate) BOOST_CHECK(usage_percentage >= 0.9); BOOST_CHECK(usage_percentage < 1); BOOST_CHECK_EQUAL( - chainstate.GetCoinsCacheSizeState(tx_pool, MAX_COINS_CACHE_BYTES, 1 << 10), + chainstate.GetCoinsCacheSizeState(&tx_pool, MAX_COINS_CACHE_BYTES, 1 << 10), CoinsCacheSizeState::LARGE); } @@ -143,7 +143,7 @@ BOOST_AUTO_TEST_CASE(getcoinscachesizestate) for (int i{0}; i < 1000; ++i) { add_coin(view); BOOST_CHECK_EQUAL( - chainstate.GetCoinsCacheSizeState(tx_pool), + chainstate.GetCoinsCacheSizeState(&tx_pool), CoinsCacheSizeState::OK); } @@ -151,7 +151,7 @@ BOOST_AUTO_TEST_CASE(getcoinscachesizestate) // preallocated memory that doesn't get reclaimed even after flush. BOOST_CHECK_EQUAL( - chainstate.GetCoinsCacheSizeState(tx_pool, MAX_COINS_CACHE_BYTES, 0), + chainstate.GetCoinsCacheSizeState(&tx_pool, MAX_COINS_CACHE_BYTES, 0), CoinsCacheSizeState::CRITICAL); view.SetBestBlock(InsecureRand256()); @@ -159,7 +159,7 @@ BOOST_AUTO_TEST_CASE(getcoinscachesizestate) print_view_mem_usage(view); BOOST_CHECK_EQUAL( - chainstate.GetCoinsCacheSizeState(tx_pool, MAX_COINS_CACHE_BYTES, 0), + chainstate.GetCoinsCacheSizeState(&tx_pool, MAX_COINS_CACHE_BYTES, 0), CoinsCacheSizeState::CRITICAL); } diff --git a/src/threadsafety.h b/src/threadsafety.h index 942aa3fdcd..5f2c40bac6 100644 --- a/src/threadsafety.h +++ b/src/threadsafety.h @@ -60,6 +60,13 @@ // and should only be used when sync.h Mutex/LOCK/etc are not usable. class LOCKABLE StdMutex : public std::mutex { +public: +#ifdef __clang__ + //! For negative capabilities in the Clang Thread Safety Analysis. + //! A negative requirement uses the EXCLUSIVE_LOCKS_REQUIRED attribute, in conjunction + //! with the ! operator, to indicate that a mutex should not be held. + const StdMutex& operator!() const { return *this; } +#endif // __clang__ }; // StdLockGuard provides an annotated version of std::lock_guard for us, diff --git a/src/timedata.cpp b/src/timedata.cpp index dd56acf44f..6b3a79017b 100644 --- a/src/timedata.cpp +++ b/src/timedata.cpp @@ -9,15 +9,14 @@ #include <timedata.h> #include <netaddress.h> +#include <node/ui_interface.h> #include <sync.h> -#include <ui_interface.h> #include <util/system.h> #include <util/translation.h> #include <warnings.h> - -static RecursiveMutex cs_nTimeOffset; -static int64_t nTimeOffset GUARDED_BY(cs_nTimeOffset) = 0; +static Mutex g_timeoffset_mutex; +static int64_t nTimeOffset GUARDED_BY(g_timeoffset_mutex) = 0; /** * "Never go to sea with two chronometers; take one or three." @@ -28,7 +27,7 @@ static int64_t nTimeOffset GUARDED_BY(cs_nTimeOffset) = 0; */ int64_t GetTimeOffset() { - LOCK(cs_nTimeOffset); + LOCK(g_timeoffset_mutex); return nTimeOffset; } @@ -46,7 +45,7 @@ static int64_t abs64(int64_t n) void AddTimeData(const CNetAddr& ip, int64_t nOffsetSample) { - LOCK(cs_nTimeOffset); + LOCK(g_timeoffset_mutex); // Ignore duplicates static std::set<CNetAddr> setKnown; if (setKnown.size() == BITCOIN_TIMEDATA_MAX_SAMPLES) @@ -57,7 +56,7 @@ void AddTimeData(const CNetAddr& ip, int64_t nOffsetSample) // Add data static CMedianFilter<int64_t> vTimeOffsets(BITCOIN_TIMEDATA_MAX_SAMPLES, 0); vTimeOffsets.input(nOffsetSample); - LogPrint(BCLog::NET,"added time data, samples %d, offset %+d (%+d minutes)\n", vTimeOffsets.size(), nOffsetSample, nOffsetSample/60); + LogPrint(BCLog::NET, "added time data, samples %d, offset %+d (%+d minutes)\n", vTimeOffsets.size(), nOffsetSample, nOffsetSample / 60); // There is a known issue here (see issue #4521): // @@ -76,33 +75,27 @@ void AddTimeData(const CNetAddr& ip, int64_t nOffsetSample) // So we should hold off on fixing this and clean it up as part of // a timing cleanup that strengthens it in a number of other ways. // - if (vTimeOffsets.size() >= 5 && vTimeOffsets.size() % 2 == 1) - { + if (vTimeOffsets.size() >= 5 && vTimeOffsets.size() % 2 == 1) { int64_t nMedian = vTimeOffsets.median(); std::vector<int64_t> vSorted = vTimeOffsets.sorted(); // Only let other nodes change our time by so much - if (abs64(nMedian) <= std::max<int64_t>(0, gArgs.GetArg("-maxtimeadjustment", DEFAULT_MAX_TIME_ADJUSTMENT))) - { + if (abs64(nMedian) <= std::max<int64_t>(0, gArgs.GetArg("-maxtimeadjustment", DEFAULT_MAX_TIME_ADJUSTMENT))) { nTimeOffset = nMedian; - } - else - { + } else { nTimeOffset = 0; static bool fDone; - if (!fDone) - { + if (!fDone) { // If nobody has a time different than ours but within 5 minutes of ours, give a warning bool fMatch = false; - for (const int64_t nOffset : vSorted) - if (nOffset != 0 && abs64(nOffset) < 5 * 60) - fMatch = true; + for (const int64_t nOffset : vSorted) { + if (nOffset != 0 && abs64(nOffset) < 5 * 60) fMatch = true; + } - if (!fMatch) - { + if (!fMatch) { fDone = true; bilingual_str strMessage = strprintf(_("Please check that your computer's date and time are correct! If your clock is wrong, %s will not work properly."), PACKAGE_NAME); - SetMiscWarning(strMessage.translated); + SetMiscWarning(strMessage); uiInterface.ThreadSafeMessageBox(strMessage, "", CClientUIInterface::MSG_WARNING); } } @@ -113,8 +106,7 @@ void AddTimeData(const CNetAddr& ip, int64_t nOffsetSample) LogPrint(BCLog::NET, "%+d ", n); /* Continued */ } LogPrint(BCLog::NET, "| "); /* Continued */ - - LogPrint(BCLog::NET, "nTimeOffset = %+d (%+d minutes)\n", nTimeOffset, nTimeOffset/60); + LogPrint(BCLog::NET, "nTimeOffset = %+d (%+d minutes)\n", nTimeOffset, nTimeOffset / 60); } } } diff --git a/src/torcontrol.cpp b/src/torcontrol.cpp index 84118b36ef..5d56d1ff89 100644 --- a/src/torcontrol.cpp +++ b/src/torcontrol.cpp @@ -405,7 +405,7 @@ static bool WriteBinaryFile(const fs::path &filename, const std::string &data) /****** Bitcoin specific TorController implementation ********/ /** Controller that connects to Tor control socket, authenticate, then create - * and maintain an ephemeral hidden service. + * and maintain an ephemeral onion service. */ class TorController { @@ -534,7 +534,7 @@ void TorController::auth_cb(TorControlConnection& _conn, const TorControlReply& // Finally - now create the service if (private_key.empty()) // No private key, generate one private_key = "NEW:RSA1024"; // Explicitly request RSA1024 - see issue #9214 - // Request hidden service, redirect port. + // Request onion service, redirect port. // Note that the 'virtual' port is always the default port to avoid decloaking nodes using other ports. _conn.Command(strprintf("ADD_ONION %s Port=%i,127.0.0.1:%i", private_key, Params().GetDefaultPort(), GetListenPort()), std::bind(&TorController::add_onion_cb, this, std::placeholders::_1, std::placeholders::_2)); diff --git a/src/txdb.cpp b/src/txdb.cpp index 129697f0e7..72460e7c69 100644 --- a/src/txdb.cpp +++ b/src/txdb.cpp @@ -5,19 +5,18 @@ #include <txdb.h> +#include <node/ui_interface.h> #include <pow.h> #include <random.h> #include <shutdown.h> -#include <ui_interface.h> #include <uint256.h> +#include <util/memory.h> #include <util/system.h> #include <util/translation.h> #include <util/vector.h> #include <stdint.h> -#include <boost/thread.hpp> - static const char DB_COIN = 'C'; static const char DB_COINS = 'c'; static const char DB_BLOCK_FILES = 'f'; @@ -41,35 +40,45 @@ struct CoinEntry { } -CCoinsViewDB::CCoinsViewDB(fs::path ldb_path, size_t nCacheSize, bool fMemory, bool fWipe) : db(ldb_path, nCacheSize, fMemory, fWipe, true) +CCoinsViewDB::CCoinsViewDB(fs::path ldb_path, size_t nCacheSize, bool fMemory, bool fWipe) : + m_db(MakeUnique<CDBWrapper>(ldb_path, nCacheSize, fMemory, fWipe, true)), + m_ldb_path(ldb_path), + m_is_memory(fMemory) { } + +void CCoinsViewDB::ResizeCache(size_t new_cache_size) { + // Have to do a reset first to get the original `m_db` state to release its + // filesystem lock. + m_db.reset(); + m_db = MakeUnique<CDBWrapper>( + m_ldb_path, new_cache_size, m_is_memory, /*fWipe*/ false, /*obfuscate*/ true); } bool CCoinsViewDB::GetCoin(const COutPoint &outpoint, Coin &coin) const { - return db.Read(CoinEntry(&outpoint), coin); + return m_db->Read(CoinEntry(&outpoint), coin); } bool CCoinsViewDB::HaveCoin(const COutPoint &outpoint) const { - return db.Exists(CoinEntry(&outpoint)); + return m_db->Exists(CoinEntry(&outpoint)); } uint256 CCoinsViewDB::GetBestBlock() const { uint256 hashBestChain; - if (!db.Read(DB_BEST_BLOCK, hashBestChain)) + if (!m_db->Read(DB_BEST_BLOCK, hashBestChain)) return uint256(); return hashBestChain; } std::vector<uint256> CCoinsViewDB::GetHeadBlocks() const { std::vector<uint256> vhashHeadBlocks; - if (!db.Read(DB_HEAD_BLOCKS, vhashHeadBlocks)) { + if (!m_db->Read(DB_HEAD_BLOCKS, vhashHeadBlocks)) { return std::vector<uint256>(); } return vhashHeadBlocks; } bool CCoinsViewDB::BatchWrite(CCoinsMap &mapCoins, const uint256 &hashBlock) { - CDBBatch batch(db); + CDBBatch batch(*m_db); size_t count = 0; size_t changed = 0; size_t batch_size = (size_t)gArgs.GetArg("-dbbatchsize", nDefaultDbBatchSize); @@ -107,7 +116,7 @@ bool CCoinsViewDB::BatchWrite(CCoinsMap &mapCoins, const uint256 &hashBlock) { mapCoins.erase(itOld); if (batch.SizeEstimate() > batch_size) { LogPrint(BCLog::COINDB, "Writing partial batch of %.2f MiB\n", batch.SizeEstimate() * (1.0 / 1048576.0)); - db.WriteBatch(batch); + m_db->WriteBatch(batch); batch.Clear(); if (crash_simulate) { static FastRandomContext rng; @@ -124,14 +133,14 @@ bool CCoinsViewDB::BatchWrite(CCoinsMap &mapCoins, const uint256 &hashBlock) { batch.Write(DB_BEST_BLOCK, hashBlock); LogPrint(BCLog::COINDB, "Writing final batch of %.2f MiB\n", batch.SizeEstimate() * (1.0 / 1048576.0)); - bool ret = db.WriteBatch(batch); + bool ret = m_db->WriteBatch(batch); LogPrint(BCLog::COINDB, "Committed %u changed transaction outputs (out of %u) to coin database...\n", (unsigned int)changed, (unsigned int)count); return ret; } size_t CCoinsViewDB::EstimateSize() const { - return db.EstimateSize(DB_COIN, (char)(DB_COIN+1)); + return m_db->EstimateSize(DB_COIN, (char)(DB_COIN+1)); } CBlockTreeDB::CBlockTreeDB(size_t nCacheSize, bool fMemory, bool fWipe) : CDBWrapper(GetDataDir() / "blocks" / "index", nCacheSize, fMemory, fWipe) { @@ -158,7 +167,7 @@ bool CBlockTreeDB::ReadLastBlockFile(int &nFile) { CCoinsViewCursor *CCoinsViewDB::Cursor() const { - CCoinsViewDBCursor *i = new CCoinsViewDBCursor(const_cast<CDBWrapper&>(db).NewIterator(), GetBestBlock()); + CCoinsViewDBCursor *i = new CCoinsViewDBCursor(const_cast<CDBWrapper&>(*m_db).NewIterator(), GetBestBlock()); /* It seems that there are no "const iterators" for LevelDB. Since we only need read operations on it, use a const-cast to get around that restriction. */ @@ -242,7 +251,6 @@ bool CBlockTreeDB::LoadBlockIndexGuts(const Consensus::Params& consensusParams, // Load m_block_index while (pcursor->Valid()) { - boost::this_thread::interruption_point(); if (ShutdownRequested()) return false; std::pair<char, uint256> key; if (pcursor->GetKey(key) && key.first == DB_BLOCK_INDEX) { @@ -338,7 +346,7 @@ public: * Currently implemented: from the per-tx utxo model (0.8..0.14.x) to per-txout. */ bool CCoinsViewDB::Upgrade() { - std::unique_ptr<CDBIterator> pcursor(db.NewIterator()); + std::unique_ptr<CDBIterator> pcursor(m_db->NewIterator()); pcursor->Seek(std::make_pair(DB_COINS, uint256())); if (!pcursor->Valid()) { return true; @@ -349,12 +357,11 @@ bool CCoinsViewDB::Upgrade() { LogPrintf("[0%%]..."); /* Continued */ uiInterface.ShowProgress(_("Upgrading UTXO database").translated, 0, true); size_t batch_size = 1 << 24; - CDBBatch batch(db); + CDBBatch batch(*m_db); int reportDone = 0; std::pair<unsigned char, uint256> key; std::pair<unsigned char, uint256> prev_key = {DB_COINS, uint256()}; while (pcursor->Valid()) { - boost::this_thread::interruption_point(); if (ShutdownRequested()) { break; } @@ -384,9 +391,9 @@ bool CCoinsViewDB::Upgrade() { } batch.Erase(key); if (batch.SizeEstimate() > batch_size) { - db.WriteBatch(batch); + m_db->WriteBatch(batch); batch.Clear(); - db.CompactRange(prev_key, key); + m_db->CompactRange(prev_key, key); prev_key = key; } pcursor->Next(); @@ -394,8 +401,8 @@ bool CCoinsViewDB::Upgrade() { break; } } - db.WriteBatch(batch); - db.CompactRange({DB_COINS, uint256()}, key); + m_db->WriteBatch(batch); + m_db->CompactRange({DB_COINS, uint256()}, key); uiInterface.ShowProgress("", 100, false); LogPrintf("[%s].\n", ShutdownRequested() ? "CANCELLED" : "DONE"); return !ShutdownRequested(); diff --git a/src/txdb.h b/src/txdb.h index 488c24f935..0cf7e2f1b8 100644 --- a/src/txdb.h +++ b/src/txdb.h @@ -39,11 +39,16 @@ static const int64_t max_filter_index_cache = 1024; //! Max memory allocated to coin DB specific cache (MiB) static const int64_t nMaxCoinsDBCache = 8; +// Actually declared in validation.cpp; can't include because of circular dependency. +extern RecursiveMutex cs_main; + /** CCoinsView backed by the coin database (chainstate/) */ class CCoinsViewDB final : public CCoinsView { protected: - CDBWrapper db; + std::unique_ptr<CDBWrapper> m_db; + fs::path m_ldb_path; + bool m_is_memory; public: /** * @param[in] ldb_path Location in the filesystem where leveldb data will be stored. @@ -60,6 +65,9 @@ public: //! Attempt to update from an older database format. Returns whether an error occurred. bool Upgrade(); size_t EstimateSize() const override; + + //! Dynamically alter the underlying leveldb cache size. + void ResizeCache(size_t new_cache_size) EXCLUSIVE_LOCKS_REQUIRED(cs_main); }; /** Specialization of CCoinsViewCursor to iterate over a CCoinsViewDB */ diff --git a/src/txmempool.cpp b/src/txmempool.cpp index c5c0208d8f..de1a3ec68f 100644 --- a/src/txmempool.cpp +++ b/src/txmempool.cpp @@ -410,7 +410,7 @@ void CTxMemPool::removeUnchecked(txiter it, MemPoolRemovalReason reason) // for any reason except being included in a block. Clients interested // in transactions included in blocks can subscribe to the BlockConnected // notification. - GetMainSignals().TransactionRemovedFromMempool(it->GetSharedTx()); + GetMainSignals().TransactionRemovedFromMempool(it->GetSharedTx(), reason); } const uint256 hash = it->GetTx().GetHash(); @@ -726,12 +726,12 @@ void CTxMemPool::check(const CCoinsViewCache *pcoins) const assert(innerUsage == cachedInnerUsage); } -bool CTxMemPool::CompareDepthAndScore(const uint256& hasha, const uint256& hashb) +bool CTxMemPool::CompareDepthAndScore(const uint256& hasha, const uint256& hashb, bool wtxid) { LOCK(cs); - indexed_transaction_set::const_iterator i = mapTx.find(hasha); + indexed_transaction_set::const_iterator i = wtxid ? get_iter_from_wtxid(hasha) : mapTx.find(hasha); if (i == mapTx.end()) return false; - indexed_transaction_set::const_iterator j = mapTx.find(hashb); + indexed_transaction_set::const_iterator j = wtxid ? get_iter_from_wtxid(hashb) : mapTx.find(hashb); if (j == mapTx.end()) return true; uint64_t counta = i->GetCountWithAncestors(); uint64_t countb = j->GetCountWithAncestors(); @@ -811,15 +811,17 @@ CTransactionRef CTxMemPool::get(const uint256& hash) const return i->GetSharedTx(); } -TxMempoolInfo CTxMemPool::info(const uint256& hash) const +TxMempoolInfo CTxMemPool::info(const GenTxid& gtxid) const { LOCK(cs); - indexed_transaction_set::const_iterator i = mapTx.find(hash); + indexed_transaction_set::const_iterator i = (gtxid.IsWtxid() ? get_iter_from_wtxid(gtxid.GetHash()) : mapTx.find(gtxid.GetHash())); if (i == mapTx.end()) return TxMempoolInfo(); return GetInfo(i); } +TxMempoolInfo CTxMemPool::info(const uint256& txid) const { return info(GenTxid{false, txid}); } + void CTxMemPool::PrioritiseTransaction(const uint256& hash, const CAmount& nFeeDelta) { { @@ -917,8 +919,8 @@ bool CCoinsViewMemPool::GetCoin(const COutPoint &outpoint, Coin &coin) const { size_t CTxMemPool::DynamicMemoryUsage() const { LOCK(cs); - // Estimate the overhead of mapTx to be 12 pointers + an allocation, as no exact formula for boost::multi_index_contained is implemented. - return memusage::MallocUsage(sizeof(CTxMemPoolEntry) + 12 * sizeof(void*)) * mapTx.size() + memusage::DynamicUsage(mapNextTx) + memusage::DynamicUsage(mapDeltas) + memusage::DynamicUsage(mapLinks) + memusage::DynamicUsage(vTxHashes) + cachedInnerUsage; + // Estimate the overhead of mapTx to be 15 pointers + an allocation, as no exact formula for boost::multi_index_contained is implemented. + return memusage::MallocUsage(sizeof(CTxMemPoolEntry) + 15 * sizeof(void*)) * mapTx.size() + memusage::DynamicUsage(mapNextTx) + memusage::DynamicUsage(mapDeltas) + memusage::DynamicUsage(mapLinks) + memusage::DynamicUsage(vTxHashes) + cachedInnerUsage; } void CTxMemPool::RemoveUnbroadcastTx(const uint256& txid, const bool unchecked) { diff --git a/src/txmempool.h b/src/txmempool.h index 583f7614b7..4743e1b63a 100644 --- a/src/txmempool.h +++ b/src/txmempool.h @@ -198,6 +198,22 @@ struct mempoolentry_txid } }; +// extracts a transaction witness-hash from CTxMemPoolEntry or CTransactionRef +struct mempoolentry_wtxid +{ + typedef uint256 result_type; + result_type operator() (const CTxMemPoolEntry &entry) const + { + return entry.GetTx().GetWitnessHash(); + } + + result_type operator() (const CTransactionRef& tx) const + { + return tx->GetWitnessHash(); + } +}; + + /** \class CompareTxMemPoolEntryByDescendantScore * * Sort an entry by max(score/size of entry's tx, score/size with all descendants). @@ -318,6 +334,7 @@ public: struct descendant_score {}; struct entry_time {}; struct ancestor_score {}; +struct index_by_wtxid {}; class CBlockPolicyEstimator; @@ -383,8 +400,9 @@ public: * * CTxMemPool::mapTx, and CTxMemPoolEntry bookkeeping: * - * mapTx is a boost::multi_index that sorts the mempool on 4 criteria: - * - transaction hash + * mapTx is a boost::multi_index that sorts the mempool on 5 criteria: + * - transaction hash (txid) + * - witness-transaction hash (wtxid) * - descendant feerate [we use max(feerate of tx, feerate of tx with all descendants)] * - time in mempool * - ancestor feerate [we use min(feerate of tx, feerate of tx with all unconfirmed ancestors)] @@ -469,6 +487,12 @@ public: boost::multi_index::indexed_by< // sorted by txid boost::multi_index::hashed_unique<mempoolentry_txid, SaltedTxidHasher>, + // sorted by wtxid + boost::multi_index::hashed_unique< + boost::multi_index::tag<index_by_wtxid>, + mempoolentry_wtxid, + SaltedTxidHasher + >, // sorted by fee rate boost::multi_index::ordered_non_unique< boost::multi_index::tag<descendant_score>, @@ -549,8 +573,11 @@ private: std::vector<indexed_transaction_set::const_iterator> GetSortedDepthAndScore() const EXCLUSIVE_LOCKS_REQUIRED(cs); - /** track locally submitted transactions to periodically retry initial broadcast */ - std::set<uint256> m_unbroadcast_txids GUARDED_BY(cs); + /** + * track locally submitted transactions to periodically retry initial broadcast + * map of txid -> wtxid + */ + std::map<uint256, uint256> m_unbroadcast_txids GUARDED_BY(cs); public: indirectmap<COutPoint, const CTransaction*> mapNextTx GUARDED_BY(cs); @@ -586,7 +613,7 @@ public: void clear(); void _clear() EXCLUSIVE_LOCKS_REQUIRED(cs); //lock free - bool CompareDepthAndScore(const uint256& hasha, const uint256& hashb); + bool CompareDepthAndScore(const uint256& hasha, const uint256& hashb, bool wtxid=false); void queryHashes(std::vector<uint256>& vtxid) const; bool isSpent(const COutPoint& outpoint) const; unsigned int GetTransactionsUpdated() const; @@ -689,24 +716,34 @@ public: return totalTxSize; } - bool exists(const uint256& hash) const + bool exists(const GenTxid& gtxid) const { LOCK(cs); - return (mapTx.count(hash) != 0); + if (gtxid.IsWtxid()) { + return (mapTx.get<index_by_wtxid>().count(gtxid.GetHash()) != 0); + } + return (mapTx.count(gtxid.GetHash()) != 0); } + bool exists(const uint256& txid) const { return exists(GenTxid{false, txid}); } CTransactionRef get(const uint256& hash) const; + txiter get_iter_from_wtxid(const uint256& wtxid) const EXCLUSIVE_LOCKS_REQUIRED(cs) + { + AssertLockHeld(cs); + return mapTx.project<0>(mapTx.get<index_by_wtxid>().find(wtxid)); + } TxMempoolInfo info(const uint256& hash) const; + TxMempoolInfo info(const GenTxid& gtxid) const; std::vector<TxMempoolInfo> infoAll() const; size_t DynamicMemoryUsage() const; /** Adds a transaction to the unbroadcast set */ - void AddUnbroadcastTx(const uint256& txid) { + void AddUnbroadcastTx(const uint256& txid, const uint256& wtxid) { LOCK(cs); // Sanity Check: the transaction should also be in the mempool if (exists(txid)) { - m_unbroadcast_txids.insert(txid); + m_unbroadcast_txids[txid] = wtxid; } } @@ -714,7 +751,7 @@ public: void RemoveUnbroadcastTx(const uint256& txid, const bool unchecked = false); /** Returns transactions in unbroadcast set */ - std::set<uint256> GetUnbroadcastTxs() const { + std::map<uint256, uint256> GetUnbroadcastTxs() const { LOCK(cs); return m_unbroadcast_txids; } diff --git a/src/uint256.cpp b/src/uint256.cpp index a943e71062..ee1b34eadd 100644 --- a/src/uint256.cpp +++ b/src/uint256.cpp @@ -12,20 +12,24 @@ template <unsigned int BITS> base_blob<BITS>::base_blob(const std::vector<unsigned char>& vch) { - assert(vch.size() == sizeof(data)); - memcpy(data, vch.data(), sizeof(data)); + assert(vch.size() == sizeof(m_data)); + memcpy(m_data, vch.data(), sizeof(m_data)); } template <unsigned int BITS> std::string base_blob<BITS>::GetHex() const { - return HexStr(std::reverse_iterator<const uint8_t*>(data + sizeof(data)), std::reverse_iterator<const uint8_t*>(data)); + uint8_t m_data_rev[WIDTH]; + for (int i = 0; i < WIDTH; ++i) { + m_data_rev[i] = m_data[WIDTH - 1 - i]; + } + return HexStr(m_data_rev); } template <unsigned int BITS> void base_blob<BITS>::SetHex(const char* psz) { - memset(data, 0, sizeof(data)); + memset(m_data, 0, sizeof(m_data)); // skip leading spaces while (IsSpace(*psz)) @@ -39,7 +43,7 @@ void base_blob<BITS>::SetHex(const char* psz) size_t digits = 0; while (::HexDigit(psz[digits]) != -1) digits++; - unsigned char* p1 = (unsigned char*)data; + unsigned char* p1 = (unsigned char*)m_data; unsigned char* pend = p1 + WIDTH; while (digits > 0 && p1 < pend) { *p1 = ::HexDigit(psz[--digits]); diff --git a/src/uint256.h b/src/uint256.h index b36598f572..8ab747ef49 100644 --- a/src/uint256.h +++ b/src/uint256.h @@ -18,11 +18,11 @@ class base_blob { protected: static constexpr int WIDTH = BITS / 8; - uint8_t data[WIDTH]; + uint8_t m_data[WIDTH]; public: base_blob() { - memset(data, 0, sizeof(data)); + memset(m_data, 0, sizeof(m_data)); } explicit base_blob(const std::vector<unsigned char>& vch); @@ -30,17 +30,17 @@ public: bool IsNull() const { for (int i = 0; i < WIDTH; i++) - if (data[i] != 0) + if (m_data[i] != 0) return false; return true; } void SetNull() { - memset(data, 0, sizeof(data)); + memset(m_data, 0, sizeof(m_data)); } - inline int Compare(const base_blob& other) const { return memcmp(data, other.data, sizeof(data)); } + inline int Compare(const base_blob& other) const { return memcmp(m_data, other.m_data, sizeof(m_data)); } friend inline bool operator==(const base_blob& a, const base_blob& b) { return a.Compare(b) == 0; } friend inline bool operator!=(const base_blob& a, const base_blob& b) { return a.Compare(b) != 0; } @@ -51,34 +51,37 @@ public: void SetHex(const std::string& str); std::string ToString() const; + const unsigned char* data() const { return m_data; } + unsigned char* data() { return m_data; } + unsigned char* begin() { - return &data[0]; + return &m_data[0]; } unsigned char* end() { - return &data[WIDTH]; + return &m_data[WIDTH]; } const unsigned char* begin() const { - return &data[0]; + return &m_data[0]; } const unsigned char* end() const { - return &data[WIDTH]; + return &m_data[WIDTH]; } unsigned int size() const { - return sizeof(data); + return sizeof(m_data); } uint64_t GetUint64(int pos) const { - const uint8_t* ptr = data + pos * 8; + const uint8_t* ptr = m_data + pos * 8; return ((uint64_t)ptr[0]) | \ ((uint64_t)ptr[1]) << 8 | \ ((uint64_t)ptr[2]) << 16 | \ @@ -92,13 +95,13 @@ public: template<typename Stream> void Serialize(Stream& s) const { - s.write((char*)data, sizeof(data)); + s.write((char*)m_data, sizeof(m_data)); } template<typename Stream> void Unserialize(Stream& s) { - s.read((char*)data, sizeof(data)); + s.read((char*)m_data, sizeof(m_data)); } }; diff --git a/src/util/check.h b/src/util/check.h index 5c0f32cf51..9edf394492 100644 --- a/src/util/check.h +++ b/src/util/check.h @@ -25,7 +25,7 @@ class NonFatalCheckError : public std::runtime_error * - where the condition is assumed to be true, not for error handling or validating user input * - where a failure to fulfill the condition is recoverable and does not abort the program * - * For example in RPC code, where it is undersirable to crash the whole program, this can be generally used to replace + * For example in RPC code, where it is undesirable to crash the whole program, this can be generally used to replace * asserts or recoverable logic errors. A NonFatalCheckError in RPC code is caught and passed as a string to the RPC * caller, which can then report the issue to the developers. */ @@ -42,4 +42,18 @@ class NonFatalCheckError : public std::runtime_error } \ } while (false) +#if defined(NDEBUG) +#error "Cannot compile without assertions!" +#endif + +/** Helper for Assert(). TODO remove in C++14 and replace `decltype(get_pure_r_value(val))` with `T` (templated lambda) */ +template <typename T> +T get_pure_r_value(T&& val) +{ + return std::forward<T>(val); +} + +/** Identity function. Abort if the value compares equal to zero */ +#define Assert(val) [&]() -> decltype(get_pure_r_value(val)) { auto&& check = (val); assert(#val && check); return std::forward<decltype(get_pure_r_value(val))>(check); }() + #endif // BITCOIN_UTIL_CHECK_H diff --git a/src/util/error.cpp b/src/util/error.cpp index 72a6e87cde..3e29083712 100644 --- a/src/util/error.cpp +++ b/src/util/error.cpp @@ -8,37 +8,37 @@ #include <util/system.h> #include <util/translation.h> -std::string TransactionErrorString(const TransactionError err) +bilingual_str TransactionErrorString(const TransactionError err) { switch (err) { case TransactionError::OK: - return "No error"; + return Untranslated("No error"); case TransactionError::MISSING_INPUTS: - return "Missing inputs"; + return Untranslated("Inputs missing or spent"); case TransactionError::ALREADY_IN_CHAIN: - return "Transaction already in block chain"; + return Untranslated("Transaction already in block chain"); case TransactionError::P2P_DISABLED: - return "Peer-to-peer functionality missing or disabled"; + return Untranslated("Peer-to-peer functionality missing or disabled"); case TransactionError::MEMPOOL_REJECTED: - return "Transaction rejected by AcceptToMemoryPool"; + return Untranslated("Transaction rejected by AcceptToMemoryPool"); case TransactionError::MEMPOOL_ERROR: - return "AcceptToMemoryPool failed"; + return Untranslated("AcceptToMemoryPool failed"); case TransactionError::INVALID_PSBT: - return "PSBT is not sane"; + return Untranslated("PSBT is not well-formed"); case TransactionError::PSBT_MISMATCH: - return "PSBTs not compatible (different transactions)"; + return Untranslated("PSBTs not compatible (different transactions)"); case TransactionError::SIGHASH_MISMATCH: - return "Specified sighash value does not match existing value"; + return Untranslated("Specified sighash value does not match value stored in PSBT"); case TransactionError::MAX_FEE_EXCEEDED: - return "Fee exceeds maximum configured by -maxtxfee"; + return Untranslated("Fee exceeds maximum configured by -maxtxfee"); // no default case, so the compiler can warn about missing cases } assert(false); } -std::string ResolveErrMsg(const std::string& optname, const std::string& strBind) +bilingual_str ResolveErrMsg(const std::string& optname, const std::string& strBind) { - return strprintf(_("Cannot resolve -%s address: '%s'").translated, optname, strBind); + return strprintf(_("Cannot resolve -%s address: '%s'"), optname, strBind); } bilingual_str AmountHighWarn(const std::string& optname) diff --git a/src/util/error.h b/src/util/error.h index 61af88ddea..b9830c9eea 100644 --- a/src/util/error.h +++ b/src/util/error.h @@ -32,9 +32,9 @@ enum class TransactionError { MAX_FEE_EXCEEDED, }; -std::string TransactionErrorString(const TransactionError error); +bilingual_str TransactionErrorString(const TransactionError error); -std::string ResolveErrMsg(const std::string& optname, const std::string& strBind); +bilingual_str ResolveErrMsg(const std::string& optname, const std::string& strBind); bilingual_str AmountHighWarn(const std::string& optname); diff --git a/src/util/fees.cpp b/src/util/fees.cpp index b335bfa666..6208a20a97 100644 --- a/src/util/fees.cpp +++ b/src/util/fees.cpp @@ -6,11 +6,16 @@ #include <util/fees.h> #include <policy/fees.h> +#include <util/strencodings.h> +#include <util/string.h> #include <map> #include <string> +#include <vector> +#include <utility> -std::string StringForFeeReason(FeeReason reason) { +std::string StringForFeeReason(FeeReason reason) +{ static const std::map<FeeReason, std::string> fee_reason_strings = { {FeeReason::NONE, "None"}, {FeeReason::HALF_ESTIMATE, "Half Target 60% Threshold"}, @@ -29,16 +34,31 @@ std::string StringForFeeReason(FeeReason reason) { return reason_string->second; } -bool FeeModeFromString(const std::string& mode_string, FeeEstimateMode& fee_estimate_mode) { - static const std::map<std::string, FeeEstimateMode> fee_modes = { - {"UNSET", FeeEstimateMode::UNSET}, - {"ECONOMICAL", FeeEstimateMode::ECONOMICAL}, - {"CONSERVATIVE", FeeEstimateMode::CONSERVATIVE}, +const std::vector<std::pair<std::string, FeeEstimateMode>>& FeeModeMap() +{ + static const std::vector<std::pair<std::string, FeeEstimateMode>> FEE_MODES = { + {"unset", FeeEstimateMode::UNSET}, + {"economical", FeeEstimateMode::ECONOMICAL}, + {"conservative", FeeEstimateMode::CONSERVATIVE}, + {(CURRENCY_UNIT + "/kB"), FeeEstimateMode::BTC_KB}, + {(CURRENCY_ATOM + "/B"), FeeEstimateMode::SAT_B}, }; - auto mode = fee_modes.find(mode_string); + return FEE_MODES; +} - if (mode == fee_modes.end()) return false; +std::string FeeModes(const std::string& delimiter) +{ + return Join(FeeModeMap(), delimiter, [&](const std::pair<std::string, FeeEstimateMode>& i) { return i.first; }); +} - fee_estimate_mode = mode->second; - return true; +bool FeeModeFromString(const std::string& mode_string, FeeEstimateMode& fee_estimate_mode) +{ + auto searchkey = ToUpper(mode_string); + for (const auto& pair : FeeModeMap()) { + if (ToUpper(pair.first) == searchkey) { + fee_estimate_mode = pair.second; + return true; + } + } + return false; } diff --git a/src/util/fees.h b/src/util/fees.h index a930c8935a..d52046a44c 100644 --- a/src/util/fees.h +++ b/src/util/fees.h @@ -12,5 +12,6 @@ enum class FeeReason; bool FeeModeFromString(const std::string& mode_string, FeeEstimateMode& fee_estimate_mode); std::string StringForFeeReason(FeeReason reason); +std::string FeeModes(const std::string& delimiter); #endif // BITCOIN_UTIL_FEES_H diff --git a/src/util/settings.cpp b/src/util/settings.cpp index e4979df755..b92b1d30c3 100644 --- a/src/util/settings.cpp +++ b/src/util/settings.cpp @@ -4,6 +4,7 @@ #include <util/settings.h> +#include <tinyformat.h> #include <univalue.h> namespace util { @@ -12,12 +13,13 @@ namespace { enum class Source { FORCED, COMMAND_LINE, + RW_SETTINGS, CONFIG_FILE_NETWORK_SECTION, CONFIG_FILE_DEFAULT_SECTION }; //! Merge settings from multiple sources in precedence order: -//! Forced config > command line > config file network-specific section > config file default section +//! Forced config > command line > read-write settings file > config file network-specific section > config file default section //! //! This function is provided with a callback function fn that contains //! specific logic for how to merge the sources. @@ -32,6 +34,10 @@ static void MergeSettings(const Settings& settings, const std::string& section, if (auto* values = FindKey(settings.command_line_options, name)) { fn(SettingsSpan(*values), Source::COMMAND_LINE); } + // Merge in the read-write settings + if (const SettingsValue* value = FindKey(settings.rw_settings, name)) { + fn(SettingsSpan(*value), Source::RW_SETTINGS); + } // Merge in the network-specific section of the config file if (!section.empty()) { if (auto* map = FindKey(settings.ro_config, section)) { @@ -49,6 +55,62 @@ static void MergeSettings(const Settings& settings, const std::string& section, } } // namespace +bool ReadSettings(const fs::path& path, std::map<std::string, SettingsValue>& values, std::vector<std::string>& errors) +{ + values.clear(); + errors.clear(); + + fsbridge::ifstream file; + file.open(path); + if (!file.is_open()) return true; // Ok for file not to exist. + + SettingsValue in; + if (!in.read(std::string{std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>()})) { + errors.emplace_back(strprintf("Unable to parse settings file %s", path.string())); + return false; + } + + if (file.fail()) { + errors.emplace_back(strprintf("Failed reading settings file %s", path.string())); + return false; + } + file.close(); // Done with file descriptor. Release while copying data. + + if (!in.isObject()) { + errors.emplace_back(strprintf("Found non-object value %s in settings file %s", in.write(), path.string())); + return false; + } + + const std::vector<std::string>& in_keys = in.getKeys(); + const std::vector<SettingsValue>& in_values = in.getValues(); + for (size_t i = 0; i < in_keys.size(); ++i) { + auto inserted = values.emplace(in_keys[i], in_values[i]); + if (!inserted.second) { + errors.emplace_back(strprintf("Found duplicate key %s in settings file %s", in_keys[i], path.string())); + } + } + return errors.empty(); +} + +bool WriteSettings(const fs::path& path, + const std::map<std::string, SettingsValue>& values, + std::vector<std::string>& errors) +{ + SettingsValue out(SettingsValue::VOBJ); + for (const auto& value : values) { + out.__pushKV(value.first, value.second); + } + fsbridge::ofstream file; + file.open(path); + if (file.fail()) { + errors.emplace_back(strprintf("Error: Unable to open settings file %s for writing", path.string())); + return false; + } + file << out.write(/* prettyIndent= */ 1, /* indentLevel= */ 4) << std::endl; + file.close(); + return true; +} + SettingsValue GetSetting(const Settings& settings, const std::string& section, const std::string& name, diff --git a/src/util/settings.h b/src/util/settings.h index 1d03639fa2..ed36349232 100644 --- a/src/util/settings.h +++ b/src/util/settings.h @@ -5,6 +5,8 @@ #ifndef BITCOIN_UTIL_SETTINGS_H #define BITCOIN_UTIL_SETTINGS_H +#include <fs.h> + #include <map> #include <string> #include <vector> @@ -24,19 +26,31 @@ namespace util { //! https://github.com/bitcoin/bitcoin/pull/15934/files#r337691812) using SettingsValue = UniValue; -//! Stored bitcoin settings. This struct combines settings from the command line -//! and a read-only configuration file. +//! Stored settings. This struct combines settings from the command line, a +//! read-only configuration file, and a read-write runtime settings file. struct Settings { //! Map of setting name to forced setting value. std::map<std::string, SettingsValue> forced_settings; //! Map of setting name to list of command line values. std::map<std::string, std::vector<SettingsValue>> command_line_options; + //! Map of setting name to read-write file setting value. + std::map<std::string, SettingsValue> rw_settings; //! Map of config section name and setting name to list of config file values. std::map<std::string, std::map<std::string, std::vector<SettingsValue>>> ro_config; }; +//! Read settings file. +bool ReadSettings(const fs::path& path, + std::map<std::string, SettingsValue>& values, + std::vector<std::string>& errors); + +//! Write settings file. +bool WriteSettings(const fs::path& path, + const std::map<std::string, SettingsValue>& values, + std::vector<std::string>& errors); + //! Get settings value from combined sources: forced settings, command line -//! arguments and the read-only config file. +//! arguments, runtime read-write settings, and the read-only config file. //! //! @param ignore_default_section_config - ignore values in the default section //! of the config file (part before any diff --git a/src/util/strencodings.cpp b/src/util/strencodings.cpp index 3a903b6897..d10f92ffe6 100644 --- a/src/util/strencodings.cpp +++ b/src/util/strencodings.cpp @@ -569,3 +569,16 @@ std::string Capitalize(std::string str) str[0] = ToUpper(str.front()); return str; } + +std::string HexStr(const Span<const uint8_t> s) +{ + std::string rv; + static constexpr char hexmap[16] = { '0', '1', '2', '3', '4', '5', '6', '7', + '8', '9', 'a', 'b', 'c', 'd', 'e', 'f' }; + rv.reserve(s.size() * 2); + for (uint8_t v: s) { + rv.push_back(hexmap[v >> 4]); + rv.push_back(hexmap[v & 15]); + } + return rv; +} diff --git a/src/util/strencodings.h b/src/util/strencodings.h index bd988f1410..eaa0fa9992 100644 --- a/src/util/strencodings.h +++ b/src/util/strencodings.h @@ -10,6 +10,7 @@ #define BITCOIN_UTIL_STRENCODINGS_H #include <attributes.h> +#include <span.h> #include <cstdint> #include <iterator> @@ -119,27 +120,11 @@ NODISCARD bool ParseUInt64(const std::string& str, uint64_t *out); */ NODISCARD bool ParseDouble(const std::string& str, double *out); -template<typename T> -std::string HexStr(const T itbegin, const T itend) -{ - std::string rv; - static const char hexmap[16] = { '0', '1', '2', '3', '4', '5', '6', '7', - '8', '9', 'a', 'b', 'c', 'd', 'e', 'f' }; - rv.reserve(std::distance(itbegin, itend) * 2); - for(T it = itbegin; it < itend; ++it) - { - unsigned char val = (unsigned char)(*it); - rv.push_back(hexmap[val>>4]); - rv.push_back(hexmap[val&15]); - } - return rv; -} - -template<typename T> -inline std::string HexStr(const T& vch) -{ - return HexStr(vch.begin(), vch.end()); -} +/** + * Convert a span of bytes to a lower-case hexadecimal string. + */ +std::string HexStr(const Span<const uint8_t> s); +inline std::string HexStr(const Span<const char> s) { return HexStr(MakeUCharSpan(s)); } /** * Format a paragraph of text to a fixed width, adding spaces for diff --git a/src/util/system.cpp b/src/util/system.cpp index bde0f097be..7b74789b32 100644 --- a/src/util/system.cpp +++ b/src/util/system.cpp @@ -6,6 +6,10 @@ #include <sync.h> #include <util/system.h> +#ifdef HAVE_BOOST_PROCESS +#include <boost/process.hpp> +#endif // HAVE_BOOST_PROCESS + #include <chainparamsbase.h> #include <util/strencodings.h> #include <util/string.h> @@ -73,6 +77,7 @@ const int64_t nStartupTime = GetTime(); const char * const BITCOIN_CONF_FILENAME = "bitcoin.conf"; +const char * const BITCOIN_SETTINGS_FILENAME = "settings.json"; ArgsManager gArgs; @@ -372,6 +377,84 @@ bool ArgsManager::IsArgSet(const std::string& strArg) const return !GetSetting(strArg).isNull(); } +bool ArgsManager::InitSettings(std::string& error) +{ + if (!GetSettingsPath()) { + return true; // Do nothing if settings file disabled. + } + + std::vector<std::string> errors; + if (!ReadSettingsFile(&errors)) { + error = strprintf("Failed loading settings file:\n- %s\n", Join(errors, "\n- ")); + return false; + } + if (!WriteSettingsFile(&errors)) { + error = strprintf("Failed saving settings file:\n- %s\n", Join(errors, "\n- ")); + return false; + } + return true; +} + +bool ArgsManager::GetSettingsPath(fs::path* filepath, bool temp) const +{ + if (IsArgNegated("-settings")) { + return false; + } + if (filepath) { + std::string settings = GetArg("-settings", BITCOIN_SETTINGS_FILENAME); + *filepath = fs::absolute(temp ? settings + ".tmp" : settings, GetDataDir(/* net_specific= */ true)); + } + return true; +} + +static void SaveErrors(const std::vector<std::string> errors, std::vector<std::string>* error_out) +{ + for (const auto& error : errors) { + if (error_out) { + error_out->emplace_back(error); + } else { + LogPrintf("%s\n", error); + } + } +} + +bool ArgsManager::ReadSettingsFile(std::vector<std::string>* errors) +{ + fs::path path; + if (!GetSettingsPath(&path, /* temp= */ false)) { + return true; // Do nothing if settings file disabled. + } + + LOCK(cs_args); + m_settings.rw_settings.clear(); + std::vector<std::string> read_errors; + if (!util::ReadSettings(path, m_settings.rw_settings, read_errors)) { + SaveErrors(read_errors, errors); + return false; + } + return true; +} + +bool ArgsManager::WriteSettingsFile(std::vector<std::string>* errors) const +{ + fs::path path, path_tmp; + if (!GetSettingsPath(&path, /* temp= */ false) || !GetSettingsPath(&path_tmp, /* temp= */ true)) { + throw std::logic_error("Attempt to write settings file when dynamic settings are disabled."); + } + + LOCK(cs_args); + std::vector<std::string> write_errors; + if (!util::WriteSettings(path_tmp, m_settings.rw_settings, write_errors)) { + SaveErrors(write_errors, errors); + return false; + } + if (!RenameOver(path_tmp, path)) { + SaveErrors({strprintf("Failed renaming settings file %s to %s\n", path_tmp.string(), path.string())}, errors); + return false; + } + return true; +} + bool ArgsManager::IsArgNegated(const std::string& strArg) const { return GetSetting(strArg).isFalse(); @@ -893,6 +976,9 @@ void ArgsManager::LogArgs() const for (const auto& section : m_settings.ro_config) { logArgsPrefix("Config file arg:", section.first, section.second); } + for (const auto& setting : m_settings.rw_settings) { + LogPrintf("Setting file arg: %s = %s\n", setting.first, setting.second.write()); + } logArgsPrefix("Command-line arg:", "", m_settings.command_line_options); } @@ -939,7 +1025,7 @@ bool FileCommit(FILE *file) return false; } #else - #if defined(__linux__) || defined(__NetBSD__) + #if defined(HAVE_FDATASYNC) if (fdatasync(fileno(file)) != 0 && errno != EINVAL) { // Ignore EINVAL for filesystems that don't support sync LogPrintf("%s: fdatasync failed: %d\n", __func__, errno); return false; @@ -1079,6 +1165,43 @@ void runCommand(const std::string& strCommand) } #endif +#ifdef HAVE_BOOST_PROCESS +UniValue RunCommandParseJSON(const std::string& str_command, const std::string& str_std_in) +{ + namespace bp = boost::process; + + UniValue result_json; + bp::opstream stdin_stream; + bp::ipstream stdout_stream; + bp::ipstream stderr_stream; + + if (str_command.empty()) return UniValue::VNULL; + + bp::child c( + str_command, + bp::std_out > stdout_stream, + bp::std_err > stderr_stream, + bp::std_in < stdin_stream + ); + if (!str_std_in.empty()) { + stdin_stream << str_std_in << std::endl; + } + stdin_stream.pipe().close(); + + std::string result; + std::string error; + std::getline(stdout_stream, result); + std::getline(stderr_stream, error); + + c.wait(); + const int n_error = c.exit_code(); + if (n_error) throw std::runtime_error(strprintf("RunCommandParseJSON error: process(%s) returned %d: %s\n", str_command, n_error, error)); + if (!result_json.read(result)) throw std::runtime_error("Unable to parse JSON: " + result); + + return result_json; +} +#endif // HAVE_BOOST_PROCESS + void SetupEnvironment() { #ifdef HAVE_MALLOPT_ARENA_MAX @@ -1163,8 +1286,9 @@ void ScheduleBatchPriority() { #ifdef SCHED_BATCH const static sched_param param{}; - if (pthread_setschedparam(pthread_self(), SCHED_BATCH, ¶m) != 0) { - LogPrintf("Failed to pthread_setschedparam: %s\n", strerror(errno)); + const int rc = pthread_setschedparam(pthread_self(), SCHED_BATCH, ¶m); + if (rc != 0) { + LogPrintf("Failed to pthread_setschedparam: %s\n", strerror(rc)); } #endif } diff --git a/src/util/system.h b/src/util/system.h index a5eea5dfab..1df194ca84 100644 --- a/src/util/system.h +++ b/src/util/system.h @@ -37,10 +37,13 @@ #include <boost/thread/condition_variable.hpp> // for boost::thread_interrupted +class UniValue; + // Application startup time (used for uptime calculation) int64_t GetStartupTime(); extern const char * const BITCOIN_CONF_FILENAME; +extern const char * const BITCOIN_SETTINGS_FILENAME; void SetupEnvironment(); bool SetupNetworking(); @@ -95,6 +98,16 @@ std::string ShellEscape(const std::string& arg); #if HAVE_SYSTEM void runCommand(const std::string& strCommand); #endif +#ifdef HAVE_BOOST_PROCESS +/** + * Execute a command which returns JSON, and parse the result. + * + * @param str_command The command to execute, including any arguments + * @param str_std_in string to pass to stdin + * @return parsed JSON + */ +UniValue RunCommandParseJSON(const std::string& str_command, const std::string& str_std_in=""); +#endif // HAVE_BOOST_PROCESS /** * Most paths passed as configuration arguments are treated as relative to @@ -334,6 +347,39 @@ public: Optional<unsigned int> GetArgFlags(const std::string& name) const; /** + * Read and update settings file with saved settings. This needs to be + * called after SelectParams() because the settings file location is + * network-specific. + */ + bool InitSettings(std::string& error); + + /** + * Get settings file path, or return false if read-write settings were + * disabled with -nosettings. + */ + bool GetSettingsPath(fs::path* filepath = nullptr, bool temp = false) const; + + /** + * Read settings file. Push errors to vector, or log them if null. + */ + bool ReadSettingsFile(std::vector<std::string>* errors = nullptr); + + /** + * Write settings file. Push errors to vector, or log them if null. + */ + bool WriteSettingsFile(std::vector<std::string>* errors = nullptr) const; + + /** + * Access settings with lock held. + */ + template <typename Fn> + void LockSettings(Fn&& fn) + { + LOCK(cs_args); + fn(m_settings); + } + + /** * Log the config file options and the command line arguments, * useful for troubleshooting. */ diff --git a/src/util/time.h b/src/util/time.h index b00c25f67c..af934e423b 100644 --- a/src/util/time.h +++ b/src/util/time.h @@ -15,10 +15,15 @@ void UninterruptibleSleep(const std::chrono::microseconds& n); /** * Helper to count the seconds of a duration. * - * All durations should be using std::chrono and calling this should generally be avoided in code. Though, it is still - * preferred to an inline t.count() to protect against a reliance on the exact type of t. + * All durations should be using std::chrono and calling this should generally + * be avoided in code. Though, it is still preferred to an inline t.count() to + * protect against a reliance on the exact type of t. + * + * This helper is used to convert durations before passing them over an + * interface that doesn't support std::chrono (e.g. RPC, debug log, or the GUI) */ inline int64_t count_seconds(std::chrono::seconds t) { return t.count(); } +inline int64_t count_microseconds(std::chrono::microseconds t) { return t.count(); } /** * DEPRECATED diff --git a/src/util/translation.h b/src/util/translation.h index 268bcf30a7..695d6dac96 100644 --- a/src/util/translation.h +++ b/src/util/translation.h @@ -23,6 +23,11 @@ struct bilingual_str { translated += rhs.translated; return *this; } + + bool empty() const + { + return original.empty(); + } }; inline bilingual_str operator+(bilingual_str lhs, const bilingual_str& rhs) diff --git a/src/util/ui_change_type.h b/src/util/ui_change_type.h new file mode 100644 index 0000000000..1db761a18d --- /dev/null +++ b/src/util/ui_change_type.h @@ -0,0 +1,15 @@ +// Copyright (c) 2012-2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#ifndef BITCOIN_UTIL_UI_CHANGE_TYPE_H +#define BITCOIN_UTIL_UI_CHANGE_TYPE_H + +/** General change type (added, updated, removed). */ +enum ChangeType { + CT_NEW, + CT_UPDATED, + CT_DELETED +}; + +#endif // BITCOIN_UTIL_UI_CHANGE_TYPE_H diff --git a/src/validation.cpp b/src/validation.cpp index 396fc0a1b5..cf2f9dde62 100644 --- a/src/validation.cpp +++ b/src/validation.cpp @@ -20,6 +20,7 @@ #include <index/txindex.h> #include <logging.h> #include <logging/timer.h> +#include <node/ui_interface.h> #include <optional.h> #include <policy/fees.h> #include <policy/policy.h> @@ -36,9 +37,9 @@ #include <tinyformat.h> #include <txdb.h> #include <txmempool.h> -#include <ui_interface.h> #include <uint256.h> #include <undo.h> +#include <util/check.h> // For NDEBUG compile time check #include <util/moneystr.h> #include <util/rbf.h> #include <util/strencodings.h> @@ -50,11 +51,6 @@ #include <string> #include <boost/algorithm/string/replace.hpp> -#include <boost/thread.hpp> - -#if defined(NDEBUG) -# error "Bitcoin cannot be compiled without assertions." -#endif #define MICRO 0.000001 #define MILLI 0.001 @@ -71,12 +67,20 @@ static const unsigned int MAX_DISCONNECTED_TX_POOL_SIZE = 20000; static const unsigned int BLOCKFILE_CHUNK_SIZE = 0x1000000; // 16 MiB /** The pre-allocation chunk size for rev?????.dat files (since 0.8) */ static const unsigned int UNDOFILE_CHUNK_SIZE = 0x100000; // 1 MiB -/** Time to wait (in seconds) between writing blocks/block index to disk. */ -static const unsigned int DATABASE_WRITE_INTERVAL = 60 * 60; -/** Time to wait (in seconds) between flushing chainstate to disk. */ -static const unsigned int DATABASE_FLUSH_INTERVAL = 24 * 60 * 60; -/** Maximum age of our tip in seconds for us to be considered current for fee estimation */ -static const int64_t MAX_FEE_ESTIMATION_TIP_AGE = 3 * 60 * 60; +/** Time to wait between writing blocks/block index to disk. */ +static constexpr std::chrono::hours DATABASE_WRITE_INTERVAL{1}; +/** Time to wait between flushing chainstate to disk. */ +static constexpr std::chrono::hours DATABASE_FLUSH_INTERVAL{24}; +/** Maximum age of our tip for us to be considered current for fee estimation */ +static constexpr std::chrono::hours MAX_FEE_ESTIMATION_TIP_AGE{3}; +const std::vector<std::string> CHECKLEVEL_DOC { + "level 0 reads the blocks from disk", + "level 1 verifies block validity", + "level 2 verifies undo data", + "level 3 checks disconnection of tip blocks", + "level 4 tries to reconnect the blocks", + "each level includes the checks of the previous levels", +}; bool CBlockIndexWorkComparator::operator()(const CBlockIndex *pa, const CBlockIndex *pb) const { // First sort by most total work, ... @@ -135,7 +139,6 @@ bool fPruneMode = false; bool fRequireStandard = true; bool fCheckBlockIndex = false; bool fCheckpointsEnabled = DEFAULT_CHECKPOINTS_ENABLED; -size_t nCoinCacheUsage = 5000 * 300; uint64_t nPruneTarget = 0; int64_t nMaxTipAge = DEFAULT_MAX_TIP_AGE; @@ -295,7 +298,7 @@ bool CheckSequenceLocks(const CTxMemPool& pool, const CTransaction& tx, int flag prevheights[txinIndex] = coin.nHeight; } } - lockPair = CalculateSequenceLocks(tx, flags, &prevheights, index); + lockPair = CalculateSequenceLocks(tx, flags, prevheights, index); if (lp) { lp->height = lockPair.first; lp->time = lockPair.second; @@ -347,7 +350,7 @@ static bool IsCurrentForFeeEstimation() EXCLUSIVE_LOCKS_REQUIRED(cs_main) AssertLockHeld(cs_main); if (::ChainstateActive().IsInitialBlockDownload()) return false; - if (::ChainActive().Tip()->GetBlockTime() < (GetTime() - MAX_FEE_ESTIMATION_TIP_AGE)) + if (::ChainActive().Tip()->GetBlockTime() < count_seconds(GetTime<std::chrono::seconds>() - MAX_FEE_ESTIMATION_TIP_AGE)) return false; if (::ChainActive().Height() < pindexBestHeader->nHeight - 1) return false; @@ -569,8 +572,9 @@ bool MemPoolAccept::PreChecks(ATMPArgs& args, Workspace& ws) CAmount& nConflictingFees = ws.m_conflicting_fees; size_t& nConflictingSize = ws.m_conflicting_size; - if (!CheckTransaction(tx, state)) + if (!CheckTransaction(tx, state)) { return false; // state filled in by CheckTransaction + } // Coinbase is only valid in a block, not as a loose transaction if (tx.IsCoinBase()) @@ -680,12 +684,13 @@ bool MemPoolAccept::PreChecks(ATMPArgs& args, Workspace& ws) CAmount nFees = 0; if (!Consensus::CheckTxInputs(tx, state, m_view, GetSpendHeight(m_view), nFees)) { - return error("%s: Consensus::CheckTxInputs: %s, %s", __func__, tx.GetHash().ToString(), state.ToString()); + return false; // state filled in by CheckTxInputs } // Check for non-standard pay-to-script-hash in inputs - if (fRequireStandard && !AreInputsStandard(tx, m_view)) - return state.Invalid(TxValidationResult::TX_NOT_STANDARD, "bad-txns-nonstandard-inputs"); + if (fRequireStandard && !AreInputsStandard(tx, m_view)) { + return state.Invalid(TxValidationResult::TX_INPUTS_NOT_STANDARD, "bad-txns-nonstandard-inputs"); + } // Check for non-standard witness in P2WSH if (tx.HasWitness() && fRequireStandard && !IsWitnessStandard(tx, m_view)) @@ -934,7 +939,7 @@ bool MemPoolAccept::PolicyScriptChecks(ATMPArgs& args, Workspace& ws, Precompute if (!tx.HasWitness() && CheckInputScripts(tx, state_dummy, m_view, scriptVerifyFlags & ~(SCRIPT_VERIFY_WITNESS | SCRIPT_VERIFY_CLEANSTACK), true, false, txdata) && !CheckInputScripts(tx, state_dummy, m_view, scriptVerifyFlags & ~SCRIPT_VERIFY_CLEANSTACK, true, false, txdata)) { // Only the witness is missing, so the transaction itself may be fine. - state.Invalid(TxValidationResult::TX_WITNESS_MUTATED, + state.Invalid(TxValidationResult::TX_WITNESS_STRIPPED, state.GetRejectReason(), state.GetDebugMessage()); } return false; // state filled in by CheckInputScripts @@ -1084,45 +1089,33 @@ bool AcceptToMemoryPool(CTxMemPool& pool, TxValidationState &state, const CTrans return AcceptToMemoryPoolWithTime(chainparams, pool, state, tx, GetTime(), plTxnReplaced, bypass_limits, nAbsurdFee, test_accept); } -/** - * Return transaction in txOut, and if it was found inside a block, its hash is placed in hashBlock. - * If blockIndex is provided, the transaction is fetched from the corresponding block. - */ -bool GetTransaction(const uint256& hash, CTransactionRef& txOut, const Consensus::Params& consensusParams, uint256& hashBlock, const CBlockIndex* const block_index) +CTransactionRef GetTransaction(const CBlockIndex* const block_index, const CTxMemPool* const mempool, const uint256& hash, const Consensus::Params& consensusParams, uint256& hashBlock) { LOCK(cs_main); - if (!block_index) { - CTransactionRef ptx = mempool.get(hash); - if (ptx) { - txOut = ptx; - return true; - } - - if (g_txindex) { - return g_txindex->FindTx(hash, hashBlock, txOut); - } - } else { + if (block_index) { CBlock block; if (ReadBlockFromDisk(block, block_index, consensusParams)) { for (const auto& tx : block.vtx) { if (tx->GetHash() == hash) { - txOut = tx; hashBlock = block_index->GetBlockHash(); - return true; + return tx; } } } + return nullptr; } - - return false; + if (mempool) { + CTransactionRef ptx = mempool->get(hash); + if (ptx) return ptx; + } + if (g_txindex) { + CTransactionRef tx; + if (g_txindex->FindTx(hash, hashBlock, tx)) return tx; + } + return nullptr; } - - - - - ////////////////////////////////////////////////////////////////////////////// // // CBlock and CBlockIndex @@ -1206,8 +1199,8 @@ bool ReadRawBlockFromDisk(std::vector<uint8_t>& block, const FlatFilePos& pos, c if (memcmp(blk_start, message_start, CMessageHeader::MESSAGE_START_SIZE)) { return error("%s: Block magic mismatch for %s: %s versus expected %s", __func__, pos.ToString(), - HexStr(blk_start, blk_start + CMessageHeader::MESSAGE_START_SIZE), - HexStr(message_start, message_start + CMessageHeader::MESSAGE_START_SIZE)); + HexStr(blk_start), + HexStr(message_start)); } if (blk_size > MAX_SIZE) { @@ -1279,9 +1272,10 @@ void CChainState::InitCoinsDB( leveldb_name, cache_size_bytes, in_memory, should_wipe); } -void CChainState::InitCoinsCache() +void CChainState::InitCoinsCache(size_t cache_size_bytes) { assert(m_coins_views != nullptr); + m_coinstip_cache_size_bytes = cache_size_bytes; m_coins_views->InitCache(); } @@ -1314,12 +1308,6 @@ bool CChainState::IsInitialBlockDownload() const static CBlockIndex *pindexBestForkTip = nullptr, *pindexBestForkBase = nullptr; -BlockMap& BlockIndex() -{ - LOCK(::cs_main); - return g_chainman.m_blockman.m_block_index; -} - static void AlertNotify(const std::string& strMessage) { uiInterface.NotifyAlertChanged(); @@ -1423,12 +1411,12 @@ void static InvalidChainFound(CBlockIndex* pindexNew) EXCLUSIVE_LOCKS_REQUIRED(c pindexBestHeader = ::ChainActive().Tip(); } - LogPrintf("%s: invalid block=%s height=%d log2_work=%.8g date=%s\n", __func__, + LogPrintf("%s: invalid block=%s height=%d log2_work=%f date=%s\n", __func__, pindexNew->GetBlockHash().ToString(), pindexNew->nHeight, log(pindexNew->nChainWork.getdouble())/log(2.0), FormatISO8601DateTime(pindexNew->GetBlockTime())); CBlockIndex *tip = ::ChainActive().Tip(); assert (tip); - LogPrintf("%s: current best=%s height=%d log2_work=%.8g date=%s\n", __func__, + LogPrintf("%s: current best=%s height=%d log2_work=%f date=%s\n", __func__, tip->GetBlockHash().ToString(), ::ChainActive().Height(), log(tip->nChainWork.getdouble())/log(2.0), FormatISO8601DateTime(tip->GetBlockTime())); CheckForkWarningConditions(); @@ -1481,14 +1469,21 @@ int GetSpendHeight(const CCoinsViewCache& inputs) } -static CuckooCache::cache<uint256, SignatureCacheHasher> scriptExecutionCache; -static uint256 scriptExecutionCacheNonce(GetRandHash()); +static CuckooCache::cache<uint256, SignatureCacheHasher> g_scriptExecutionCache; +static CSHA256 g_scriptExecutionCacheHasher; void InitScriptExecutionCache() { + // Setup the salted hasher + uint256 nonce = GetRandHash(); + // We want the nonce to be 64 bytes long to force the hasher to process + // this chunk, which makes later hash computations more efficient. We + // just write our 32-byte entropy twice to fill the 64 bytes. + g_scriptExecutionCacheHasher.Write(nonce.begin(), 32); + g_scriptExecutionCacheHasher.Write(nonce.begin(), 32); // nMaxCacheSize is unsigned. If -maxsigcachesize is set to zero, // setup_bytes creates the minimum possible cache (2 elements). size_t nMaxCacheSize = std::min(std::max((int64_t)0, gArgs.GetArg("-maxsigcachesize", DEFAULT_MAX_SIG_CACHE_SIZE) / 2), MAX_MAX_SIG_CACHE_SIZE) * ((size_t) 1 << 20); - size_t nElems = scriptExecutionCache.setup_bytes(nMaxCacheSize); + size_t nElems = g_scriptExecutionCache.setup_bytes(nMaxCacheSize); LogPrintf("Using %zu MiB out of %zu/2 requested for script execution cache, able to store %zu elements\n", (nElems*sizeof(uint256)) >>20, (nMaxCacheSize*2)>>20, nElems); } @@ -1526,12 +1521,10 @@ bool CheckInputScripts(const CTransaction& tx, TxValidationState &state, const C // properly commits to the scriptPubKey in the inputs view of that // transaction). uint256 hashCacheEntry; - // We only use the first 19 bytes of nonce to avoid a second SHA - // round - giving us 19 + 32 + 4 = 55 bytes (+ 8 + 1 = 64) - static_assert(55 - sizeof(flags) - 32 >= 128/8, "Want at least 128 bits of nonce for script execution cache"); - CSHA256().Write(scriptExecutionCacheNonce.begin(), 55 - sizeof(flags) - 32).Write(tx.GetWitnessHash().begin(), 32).Write((unsigned char*)&flags, sizeof(flags)).Finalize(hashCacheEntry.begin()); + CSHA256 hasher = g_scriptExecutionCacheHasher; + hasher.Write(tx.GetWitnessHash().begin(), 32).Write((unsigned char*)&flags, sizeof(flags)).Finalize(hashCacheEntry.begin()); AssertLockHeld(cs_main); //TODO: Remove this requirement by making CuckooCache not require external locks - if (scriptExecutionCache.contains(hashCacheEntry, !cacheFullScriptStore)) { + if (g_scriptExecutionCache.contains(hashCacheEntry, !cacheFullScriptStore)) { return true; } @@ -1586,7 +1579,7 @@ bool CheckInputScripts(const CTransaction& tx, TxValidationState &state, const C if (cacheFullScriptStore && !pvChecks) { // We executed all of the provided scripts, and were told to // cache the result. Do so now. - scriptExecutionCache.insert(hashCacheEntry); + g_scriptExecutionCache.insert(hashCacheEntry); } return true; @@ -1651,23 +1644,21 @@ bool UndoReadFromDisk(CBlockUndo& blockundo, const CBlockIndex* pindex) } /** Abort with a message */ -// TODO: AbortNode() should take bilingual_str userMessage parameter. -static bool AbortNode(const std::string& strMessage, const std::string& userMessage = "", unsigned int prefix = 0) +static bool AbortNode(const std::string& strMessage, bilingual_str user_message = bilingual_str()) { - SetMiscWarning(strMessage); + SetMiscWarning(Untranslated(strMessage)); LogPrintf("*** %s\n", strMessage); - if (!userMessage.empty()) { - uiInterface.ThreadSafeMessageBox(Untranslated(userMessage), "", CClientUIInterface::MSG_ERROR | prefix); - } else { - uiInterface.ThreadSafeMessageBox(_("Error: A fatal internal error occurred, see debug.log for details"), "", CClientUIInterface::MSG_ERROR | CClientUIInterface::MSG_NOPREFIX); + if (user_message.empty()) { + user_message = _("A fatal internal error occurred, see debug.log for details"); } + AbortError(user_message); StartShutdown(); return false; } -static bool AbortNode(BlockValidationState& state, const std::string& strMessage, const std::string& userMessage = "", unsigned int prefix = 0) +static bool AbortNode(BlockValidationState& state, const std::string& strMessage, const bilingual_str& userMessage = bilingual_str()) { - AbortNode(strMessage, userMessage, prefix); + AbortNode(strMessage, userMessage); return state.Error(strMessage); } @@ -1765,19 +1756,24 @@ DisconnectResult CChainState::DisconnectBlock(const CBlock& block, const CBlockI return fClean ? DISCONNECT_OK : DISCONNECT_UNCLEAN; } -void static FlushBlockFile(bool fFinalize = false) +static void FlushUndoFile(int block_file, bool finalize = false) { - LOCK(cs_LastBlockFile); + FlatFilePos undo_pos_old(block_file, vinfoBlockFile[block_file].nUndoSize); + if (!UndoFileSeq().Flush(undo_pos_old, finalize)) { + AbortNode("Flushing undo file to disk failed. This is likely the result of an I/O error."); + } +} +static void FlushBlockFile(bool fFinalize = false, bool finalize_undo = false) +{ + LOCK(cs_LastBlockFile); FlatFilePos block_pos_old(nLastBlockFile, vinfoBlockFile[nLastBlockFile].nSize); - FlatFilePos undo_pos_old(nLastBlockFile, vinfoBlockFile[nLastBlockFile].nUndoSize); - - bool status = true; - status &= BlockFileSeq().Flush(block_pos_old, fFinalize); - status &= UndoFileSeq().Flush(undo_pos_old, fFinalize); - if (!status) { + if (!BlockFileSeq().Flush(block_pos_old, fFinalize)) { AbortNode("Flushing block file to disk failed. This is likely the result of an I/O error."); } + // we do not always flush the undo file, as the chain tip may be lagging behind the incoming blocks, + // e.g. during IBD or a sync after a node going offline + if (!fFinalize || finalize_undo) FlushUndoFile(nLastBlockFile, finalize_undo); } static bool FindUndoPos(BlockValidationState &state, int nFile, FlatFilePos &pos, unsigned int nAddSize); @@ -1791,6 +1787,14 @@ static bool WriteUndoDataForBlock(const CBlockUndo& blockundo, BlockValidationSt return error("ConnectBlock(): FindUndoPos failed"); if (!UndoWriteToDisk(blockundo, _pos, pindex->pprev->GetBlockHash(), chainparams.MessageStart())) return AbortNode(state, "Failed to write undo data"); + // rev files are written in block height order, whereas blk files are written as blocks come in (often out of order) + // we want to flush the rev (undo) file once we've written the last block, which is indicated by the last height + // in the block file info as below; note that this does not catch the case where the undo writes are keeping up + // with the block writes (usually when a synced up node is getting newly mined blocks) -- this case is caught in + // the FindBlockPos function + if (_pos.nFile < nLastBlockFile && static_cast<uint32_t>(pindex->nHeight) == vinfoBlockFile[_pos.nFile].nHeightLast) { + FlushUndoFile(_pos.nFile, true); + } // update nUndoPos in block index pindex->nUndoPos = _pos.nPos; @@ -2147,7 +2151,7 @@ bool CChainState::ConnectBlock(const CBlock& block, BlockValidationState& state, prevheights[j] = view.AccessCoin(tx.vin[j].prevout).nHeight; } - if (!SequenceLocks(tx, nLockTimeFlags, &prevheights, *pindex)) { + if (!SequenceLocks(tx, nLockTimeFlags, prevheights, *pindex)) { LogPrintf("ERROR: %s: contains a non-BIP68-final transaction\n", __func__); return state.Invalid(BlockValidationResult::BLOCK_CONSENSUS, "bad-txns-nonfinal"); } @@ -2224,20 +2228,20 @@ bool CChainState::ConnectBlock(const CBlock& block, BlockValidationState& state, return true; } -CoinsCacheSizeState CChainState::GetCoinsCacheSizeState(const CTxMemPool& tx_pool) +CoinsCacheSizeState CChainState::GetCoinsCacheSizeState(const CTxMemPool* tx_pool) { return this->GetCoinsCacheSizeState( tx_pool, - nCoinCacheUsage, + m_coinstip_cache_size_bytes, gArgs.GetArg("-maxmempool", DEFAULT_MAX_MEMPOOL_SIZE) * 1000000); } CoinsCacheSizeState CChainState::GetCoinsCacheSizeState( - const CTxMemPool& tx_pool, + const CTxMemPool* tx_pool, size_t max_coins_cache_size_bytes, size_t max_mempool_size_bytes) { - int64_t nMempoolUsage = tx_pool.DynamicMemoryUsage(); + const int64_t nMempoolUsage = tx_pool ? tx_pool->DynamicMemoryUsage() : 0; int64_t cacheSize = CoinsTip().DynamicMemoryUsage(); int64_t nTotalSpace = max_coins_cache_size_bytes + std::max<int64_t>(max_mempool_size_bytes - nMempoolUsage, 0); @@ -2264,8 +2268,8 @@ bool CChainState::FlushStateToDisk( { LOCK(cs_main); assert(this->CanFlushToDisk()); - static int64_t nLastWrite = 0; - static int64_t nLastFlush = 0; + static std::chrono::microseconds nLastWrite{0}; + static std::chrono::microseconds nLastFlush{0}; std::set<int> setFilesToPrune; bool full_flush_completed = false; @@ -2276,7 +2280,7 @@ bool CChainState::FlushStateToDisk( { bool fFlushForPrune = false; bool fDoFullFlush = false; - CoinsCacheSizeState cache_state = GetCoinsCacheSizeState(::mempool); + CoinsCacheSizeState cache_state = GetCoinsCacheSizeState(&::mempool); LOCK(cs_LastBlockFile); if (fPruneMode && (fCheckForPruning || nManualPruneHeight > 0) && !fReindex) { if (nManualPruneHeight > 0) { @@ -2297,12 +2301,12 @@ bool CChainState::FlushStateToDisk( } } } - int64_t nNow = GetTimeMicros(); + const auto nNow = GetTime<std::chrono::microseconds>(); // Avoid writing/flushing immediately after startup. - if (nLastWrite == 0) { + if (nLastWrite.count() == 0) { nLastWrite = nNow; } - if (nLastFlush == 0) { + if (nLastFlush.count() == 0) { nLastFlush = nNow; } // The cache is large and we're within 10% and 10 MiB of the limit, but we have time now (not in the middle of a block processing). @@ -2310,16 +2314,16 @@ bool CChainState::FlushStateToDisk( // The cache is over the limit, we have to write now. bool fCacheCritical = mode == FlushStateMode::IF_NEEDED && cache_state >= CoinsCacheSizeState::CRITICAL; // It's been a while since we wrote the block index to disk. Do this frequently, so we don't need to redownload after a crash. - bool fPeriodicWrite = mode == FlushStateMode::PERIODIC && nNow > nLastWrite + (int64_t)DATABASE_WRITE_INTERVAL * 1000000; + bool fPeriodicWrite = mode == FlushStateMode::PERIODIC && nNow > nLastWrite + DATABASE_WRITE_INTERVAL; // It's been very long since we flushed the cache. Do this infrequently, to optimize cache usage. - bool fPeriodicFlush = mode == FlushStateMode::PERIODIC && nNow > nLastFlush + (int64_t)DATABASE_FLUSH_INTERVAL * 1000000; + bool fPeriodicFlush = mode == FlushStateMode::PERIODIC && nNow > nLastFlush + DATABASE_FLUSH_INTERVAL; // Combine all conditions that result in a full cache flush. fDoFullFlush = (mode == FlushStateMode::ALWAYS) || fCacheLarge || fCacheCritical || fPeriodicFlush || fFlushForPrune; // Write blocks and block index to disk. if (fDoFullFlush || fPeriodicWrite) { // Depend on nMinDiskSpace to ensure we can write block index if (!CheckDiskSpace(GetBlocksDir())) { - return AbortNode(state, "Disk space is too low!", _("Error: Disk space is too low!").translated, CClientUIInterface::MSG_NOPREFIX); + return AbortNode(state, "Disk space is too low!", _("Disk space is too low!")); } { LOG_TIME_MILLIS_WITH_CATEGORY("write block and undo data to disk", BCLog::BENCH); @@ -2367,7 +2371,7 @@ bool CChainState::FlushStateToDisk( // an overestimation, as most will delete an existing entry or // overwrite one. Still, use a conservative safety factor of 2. if (!CheckDiskSpace(GetDataDir(), 48 * 2 * 2 * CoinsTip().GetCacheSize())) { - return AbortNode(state, "Disk space is too low!", _("Error: Disk space is too low!").translated, CClientUIInterface::MSG_NOPREFIX); + return AbortNode(state, "Disk space is too low!", _("Disk space is too low!")); } // Flush the chainstate (which may refer to block index entries). if (!CoinsTip().Flush()) @@ -2404,20 +2408,20 @@ void CChainState::PruneAndFlush() { } } -static void DoWarning(const std::string& strWarning) +static void DoWarning(const bilingual_str& warning) { static bool fWarned = false; - SetMiscWarning(strWarning); + SetMiscWarning(warning); if (!fWarned) { - AlertNotify(strWarning); + AlertNotify(warning.original); fWarned = true; } } /** Private helper function that concatenates warning messages. */ -static void AppendWarning(std::string& res, const std::string& warn) +static void AppendWarning(bilingual_str& res, const bilingual_str& warn) { - if (!res.empty()) res += ", "; + if (!res.empty()) res += Untranslated(", "); res += warn; } @@ -2434,7 +2438,7 @@ void static UpdateTip(const CBlockIndex* pindexNew, const CChainParams& chainPar g_best_block_cv.notify_all(); } - std::string warningMessages; + bilingual_str warning_messages; if (!::ChainstateActive().IsInitialBlockDownload()) { int nUpgraded = 0; @@ -2443,11 +2447,11 @@ void static UpdateTip(const CBlockIndex* pindexNew, const CChainParams& chainPar WarningBitsConditionChecker checker(bit); ThresholdState state = checker.GetStateFor(pindex, chainParams.GetConsensus(), warningcache[bit]); if (state == ThresholdState::ACTIVE || state == ThresholdState::LOCKED_IN) { - const std::string strWarning = strprintf(_("Warning: unknown new rules activated (versionbit %i)").translated, bit); + const bilingual_str warning = strprintf(_("Warning: unknown new rules activated (versionbit %i)"), bit); if (state == ThresholdState::ACTIVE) { - DoWarning(strWarning); + DoWarning(warning); } else { - AppendWarning(warningMessages, strWarning); + AppendWarning(warning_messages, warning); } } } @@ -2460,14 +2464,14 @@ void static UpdateTip(const CBlockIndex* pindexNew, const CChainParams& chainPar pindex = pindex->pprev; } if (nUpgraded > 0) - AppendWarning(warningMessages, strprintf(_("%d of last 100 blocks have unexpected version").translated, nUpgraded)); + AppendWarning(warning_messages, strprintf(_("%d of last 100 blocks have unexpected version"), nUpgraded)); } - LogPrintf("%s: new best=%s height=%d version=0x%08x log2_work=%.8g tx=%lu date='%s' progress=%f cache=%.1fMiB(%utxo)%s\n", __func__, + LogPrintf("%s: new best=%s height=%d version=0x%08x log2_work=%f tx=%lu date='%s' progress=%f cache=%.1fMiB(%utxo)%s\n", __func__, pindexNew->GetBlockHash().ToString(), pindexNew->nHeight, pindexNew->nVersion, log(pindexNew->nChainWork.getdouble())/log(2.0), (unsigned long)pindexNew->nChainTx, FormatISO8601DateTime(pindexNew->GetBlockTime()), GuessVerificationProgress(chainParams.TxData(), pindexNew), ::ChainstateActive().CoinsTip().DynamicMemoryUsage() * (1.0 / (1<<20)), ::ChainstateActive().CoinsTip().GetCacheSize(), - !warningMessages.empty() ? strprintf(" warning='%s'", warningMessages) : ""); + !warning_messages.empty() ? strprintf(" warning='%s'", warning_messages.original) : ""); } @@ -2854,8 +2858,6 @@ bool CChainState::ActivateBestChain(BlockValidationState &state, const CChainPar CBlockIndex *pindexNewTip = nullptr; int nStopAtHeight = gArgs.GetArg("-stopatheight", DEFAULT_STOPATHEIGHT); do { - boost::this_thread::interruption_point(); - // Block until the validation queue drains. This should largely // never happen in normal operation, however may happen during // reindex, causing memory blowup if we run too far ahead. @@ -2924,8 +2926,7 @@ bool CChainState::ActivateBestChain(BlockValidationState &state, const CChainPar // never shutdown before connecting the genesis block during LoadChainTip(). Previously this // caused an assert() failure during shutdown in such cases as the UTXO DB flushing checks // that the best block hash is non-null. - if (ShutdownRequested()) - break; + if (ShutdownRequested()) break; } while (pindexNewTip != pindexMostWork); CheckBlockIndex(chainparams.GetConsensus()); @@ -3243,8 +3244,13 @@ static bool FindBlockPos(FlatFilePos &pos, unsigned int nAddSize, unsigned int n vinfoBlockFile.resize(nFile + 1); } + bool finalize_undo = false; if (!fKnown) { while (vinfoBlockFile[nFile].nSize + nAddSize >= MAX_BLOCKFILE_SIZE) { + // when the undo file is keeping up with the block file, we want to flush it explicitly + // when it is lagging behind (more blocks arrive than are being connected), we let the + // undo block write case handle it + finalize_undo = (vinfoBlockFile[nFile].nHeightLast == (unsigned int)ChainActive().Tip()->nHeight); nFile++; if (vinfoBlockFile.size() <= nFile) { vinfoBlockFile.resize(nFile + 1); @@ -3258,7 +3264,7 @@ static bool FindBlockPos(FlatFilePos &pos, unsigned int nAddSize, unsigned int n if (!fKnown) { LogPrintf("Leaving block file %i: %s\n", nLastBlockFile, vinfoBlockFile[nLastBlockFile].ToString()); } - FlushBlockFile(!fKnown); + FlushBlockFile(!fKnown, finalize_undo); nLastBlockFile = nFile; } @@ -3272,7 +3278,7 @@ static bool FindBlockPos(FlatFilePos &pos, unsigned int nAddSize, unsigned int n bool out_of_space; size_t bytes_allocated = BlockFileSeq().Allocate(pos, nAddSize, out_of_space); if (out_of_space) { - return AbortNode("Disk space is too low!", _("Error: Disk space is too low!").translated, CClientUIInterface::MSG_NOPREFIX); + return AbortNode("Disk space is too low!", _("Disk space is too low!")); } if (bytes_allocated != 0 && fPruneMode) { fCheckForPruning = true; @@ -3296,7 +3302,7 @@ static bool FindUndoPos(BlockValidationState &state, int nFile, FlatFilePos &pos bool out_of_space; size_t bytes_allocated = UndoFileSeq().Allocate(pos, nAddSize, out_of_space); if (out_of_space) { - return AbortNode(state, "Disk space is too low!", _("Error: Disk space is too low!").translated, CClientUIInterface::MSG_NOPREFIX); + return AbortNode(state, "Disk space is too low!", _("Disk space is too low!")); } if (bytes_allocated != 0 && fPruneMode) { fCheckForPruning = true; @@ -3429,7 +3435,7 @@ std::vector<unsigned char> GenerateCoinbaseCommitment(CBlock& block, const CBloc if (consensusParams.SegwitHeight != std::numeric_limits<int>::max()) { if (commitpos == -1) { uint256 witnessroot = BlockWitnessMerkleRoot(block, nullptr); - CHash256().Write(witnessroot.begin(), 32).Write(ret.data(), 32).Finalize(witnessroot.begin()); + CHash256().Write(witnessroot).Write(ret).Finalize(witnessroot); CTxOut out; out.nValue = 0; out.scriptPubKey.resize(MINIMUM_WITNESS_COMMITMENT); @@ -3574,7 +3580,7 @@ static bool ContextualCheckBlock(const CBlock& block, BlockValidationState& stat if (block.vtx[0]->vin[0].scriptWitness.stack.size() != 1 || block.vtx[0]->vin[0].scriptWitness.stack[0].size() != 32) { return state.Invalid(BlockValidationResult::BLOCK_MUTATED, "bad-witness-nonce-size", strprintf("%s : invalid witness reserved value size", __func__)); } - CHash256().Write(hashWitness.begin(), 32).Write(&block.vtx[0]->vin[0].scriptWitness.stack[0][0], 32).Finalize(hashWitness.begin()); + CHash256().Write(hashWitness).Write(block.vtx[0]->vin[0].scriptWitness.stack[0]).Finalize(hashWitness); if (memcmp(hashWitness.begin(), &block.vtx[0]->vout[commitpos].scriptPubKey[6], 32)) { return state.Invalid(BlockValidationResult::BLOCK_MUTATED, "bad-witness-merkle-match", strprintf("%s : witness merkle commitment mismatch", __func__)); } @@ -3624,8 +3630,10 @@ bool BlockManager::AcceptBlockHeader(const CBlockHeader& block, BlockValidationS return true; } - if (!CheckBlockHeader(block, state, chainparams.GetConsensus())) - return error("%s: Consensus::CheckBlockHeader: %s, %s", __func__, hash.ToString(), state.ToString()); + if (!CheckBlockHeader(block, state, chainparams.GetConsensus())) { + LogPrint(BCLog::VALIDATION, "%s: Consensus::CheckBlockHeader: %s, %s\n", __func__, hash.ToString(), state.ToString()); + return false; + } // Get prev block index CBlockIndex* pindexPrev = nullptr; @@ -4267,7 +4275,6 @@ bool CVerifyDB::VerifyDB(const CChainParams& chainparams, CCoinsView *coinsview, int reportDone = 0; LogPrintf("[0%%]..."); /* Continued */ for (pindex = ::ChainActive().Tip(); pindex && pindex->pprev; pindex = pindex->pprev) { - boost::this_thread::interruption_point(); const int percentageDone = std::max(1, std::min(99, (int)(((double)(::ChainActive().Height() - pindex->nHeight)) / (double)nCheckDepth * (nCheckLevel >= 4 ? 50 : 100)))); if (reportDone < percentageDone/10) { // report every 10% step @@ -4300,7 +4307,7 @@ bool CVerifyDB::VerifyDB(const CChainParams& chainparams, CCoinsView *coinsview, } } // check level 3: check for inconsistencies during memory-only disconnect of tip blocks - if (nCheckLevel >= 3 && (coins.DynamicMemoryUsage() + ::ChainstateActive().CoinsTip().DynamicMemoryUsage()) <= nCoinCacheUsage) { + if (nCheckLevel >= 3 && (coins.DynamicMemoryUsage() + ::ChainstateActive().CoinsTip().DynamicMemoryUsage()) <= ::ChainstateActive().m_coinstip_cache_size_bytes) { assert(coins.GetBestBlock() == pindex->GetBlockHash()); DisconnectResult res = ::ChainstateActive().DisconnectBlock(block, pindex, coins); if (res == DISCONNECT_FAILED) { @@ -4313,8 +4320,7 @@ bool CVerifyDB::VerifyDB(const CChainParams& chainparams, CCoinsView *coinsview, nGoodTransactions += block.vtx.size(); } } - if (ShutdownRequested()) - return true; + if (ShutdownRequested()) return true; } if (pindexFailure) return error("VerifyDB(): *** coin database inconsistencies found (last %i blocks, %i good transactions before that)\n", ::ChainActive().Height() - pindexFailure->nHeight + 1, nGoodTransactions); @@ -4325,7 +4331,6 @@ bool CVerifyDB::VerifyDB(const CChainParams& chainparams, CCoinsView *coinsview, // check level 4: try reconnecting blocks if (nCheckLevel >= 4) { while (pindex != ::ChainActive().Tip()) { - boost::this_thread::interruption_point(); const int percentageDone = std::max(1, std::min(99, 100 - (int)(((double)(::ChainActive().Height() - pindex->nHeight)) / (double)nCheckDepth * 50))); if (reportDone < percentageDone/10) { // report every 10% step @@ -4339,6 +4344,7 @@ bool CVerifyDB::VerifyDB(const CChainParams& chainparams, CCoinsView *coinsview, return error("VerifyDB(): *** ReadBlockFromDisk failed at %d, hash=%s", pindex->nHeight, pindex->GetBlockHash().ToString()); if (!::ChainstateActive().ConnectBlock(block, state, pindex, coins, chainparams)) return error("VerifyDB(): *** found unconnectable block at %d, hash=%s (%s)", pindex->nHeight, pindex->GetBlockHash().ToString(), state.ToString()); + if (ShutdownRequested()) return true; } } @@ -4582,13 +4588,13 @@ void CChainState::UnloadBlockIndex() { // May NOT be used after any connections are up as much // of the peer-processing logic assumes a consistent // block index state -void UnloadBlockIndex() +void UnloadBlockIndex(CTxMemPool* mempool) { LOCK(cs_main); g_chainman.Unload(); pindexBestInvalid = nullptr; pindexBestHeader = nullptr; - mempool.clear(); + if (mempool) mempool->clear(); vinfoBlockFile.clear(); nLastBlockFile = 0; setDirtyBlockIndex.clear(); @@ -4693,7 +4699,6 @@ void LoadExternalBlockFile(const CChainParams& chainparams, FILE* fileIn, FlatFi if (dbp) dbp->nPos = nBlockPos; blkdat.SetLimit(nBlockPos + nSize); - blkdat.SetPos(nBlockPos); std::shared_ptr<CBlock> pblock = std::make_shared<CBlock>(); CBlock& block = *pblock; blkdat >> block; @@ -4965,6 +4970,39 @@ std::string CChainState::ToString() tip ? tip->nHeight : -1, tip ? tip->GetBlockHash().ToString() : "null"); } +bool CChainState::ResizeCoinsCaches(size_t coinstip_size, size_t coinsdb_size) +{ + if (coinstip_size == m_coinstip_cache_size_bytes && + coinsdb_size == m_coinsdb_cache_size_bytes) { + // Cache sizes are unchanged, no need to continue. + return true; + } + size_t old_coinstip_size = m_coinstip_cache_size_bytes; + m_coinstip_cache_size_bytes = coinstip_size; + m_coinsdb_cache_size_bytes = coinsdb_size; + CoinsDB().ResizeCache(coinsdb_size); + + LogPrintf("[%s] resized coinsdb cache to %.1f MiB\n", + this->ToString(), coinsdb_size * (1.0 / 1024 / 1024)); + LogPrintf("[%s] resized coinstip cache to %.1f MiB\n", + this->ToString(), coinstip_size * (1.0 / 1024 / 1024)); + + BlockValidationState state; + const CChainParams& chainparams = Params(); + + bool ret; + + if (coinstip_size > old_coinstip_size) { + // Likely no need to flush if cache sizes have grown. + ret = FlushStateToDisk(chainparams, state, FlushStateMode::IF_NEEDED); + } else { + // Otherwise, flush state to disk and deallocate the in-memory coins map. + ret = FlushStateToDisk(chainparams, state, FlushStateMode::ALWAYS); + CoinsTip().ReallocateCache(); + } + return ret; +} + std::string CBlockFileInfo::ToString() const { return strprintf("CBlockFileInfo(blocks=%u, size=%u, heights=%u...%u, time=%s...%s)", nBlocks, nSize, nHeightFirst, nHeightLast, FormatISO8601Date(nTimeFirst), FormatISO8601Date(nTimeLast)); @@ -5068,19 +5106,22 @@ bool LoadMempool(CTxMemPool& pool) } // TODO: remove this try except in v0.22 + std::map<uint256, uint256> unbroadcast_txids; try { - std::set<uint256> unbroadcast_txids; file >> unbroadcast_txids; unbroadcast = unbroadcast_txids.size(); - - for (const auto& txid : unbroadcast_txids) { - pool.AddUnbroadcastTx(txid); - } } catch (const std::exception&) { // mempool.dat files created prior to v0.21 will not have an // unbroadcast set. No need to log a failure if parsing fails here. } - + for (const auto& elem : unbroadcast_txids) { + // Don't add unbroadcast transactions that didn't get back into the + // mempool. + const CTransactionRef& added_tx = pool.get(elem.first); + if (added_tx != nullptr) { + pool.AddUnbroadcastTx(elem.first, added_tx->GetWitnessHash()); + } + } } catch (const std::exception& e) { LogPrintf("Failed to deserialize mempool data on disk: %s. Continuing anyway.\n", e.what()); return false; @@ -5096,7 +5137,7 @@ bool DumpMempool(const CTxMemPool& pool) std::map<uint256, CAmount> mapDeltas; std::vector<TxMempoolInfo> vinfo; - std::set<uint256> unbroadcast_txids; + std::map<uint256, uint256> unbroadcast_txids; static Mutex dump_mutex; LOCK(dump_mutex); @@ -5228,10 +5269,10 @@ CChainState& ChainstateManager::InitializeChainstate(const uint256& snapshot_blo return *to_modify; } -CChain& ChainstateManager::ActiveChain() const +CChainState& ChainstateManager::ActiveChainstate() const { assert(m_active_chainstate); - return m_active_chainstate->m_chain; + return *m_active_chainstate; } bool ChainstateManager::IsSnapshotActive() const @@ -5270,3 +5311,33 @@ void ChainstateManager::Reset() m_active_chainstate = nullptr; m_snapshot_validated = false; } + +void ChainstateManager::MaybeRebalanceCaches() +{ + if (m_ibd_chainstate && !m_snapshot_chainstate) { + LogPrintf("[snapshot] allocating all cache to the IBD chainstate\n"); + // Allocate everything to the IBD chainstate. + m_ibd_chainstate->ResizeCoinsCaches(m_total_coinstip_cache, m_total_coinsdb_cache); + } + else if (m_snapshot_chainstate && !m_ibd_chainstate) { + LogPrintf("[snapshot] allocating all cache to the snapshot chainstate\n"); + // Allocate everything to the snapshot chainstate. + m_snapshot_chainstate->ResizeCoinsCaches(m_total_coinstip_cache, m_total_coinsdb_cache); + } + else if (m_ibd_chainstate && m_snapshot_chainstate) { + // If both chainstates exist, determine who needs more cache based on IBD status. + // + // Note: shrink caches first so that we don't inadvertently overwhelm available memory. + if (m_snapshot_chainstate->IsInitialBlockDownload()) { + m_ibd_chainstate->ResizeCoinsCaches( + m_total_coinstip_cache * 0.05, m_total_coinsdb_cache * 0.05); + m_snapshot_chainstate->ResizeCoinsCaches( + m_total_coinstip_cache * 0.95, m_total_coinsdb_cache * 0.95); + } else { + m_snapshot_chainstate->ResizeCoinsCaches( + m_total_coinstip_cache * 0.05, m_total_coinsdb_cache * 0.05); + m_ibd_chainstate->ResizeCoinsCaches( + m_total_coinstip_cache * 0.95, m_total_coinsdb_cache * 0.95); + } + } +} diff --git a/src/validation.h b/src/validation.h index 8112e38704..534162d64a 100644 --- a/src/validation.h +++ b/src/validation.h @@ -29,6 +29,7 @@ #include <memory> #include <set> #include <stdint.h> +#include <string> #include <utility> #include <vector> @@ -73,7 +74,6 @@ static const int64_t DEFAULT_MAX_TIP_AGE = 24 * 60 * 60; static const bool DEFAULT_CHECKPOINTS_ENABLED = true; static const bool DEFAULT_TXINDEX = false; static const char* const DEFAULT_BLOCKFILTERINDEX = "0"; -static const unsigned int DEFAULT_BANSCORE_THRESHOLD = 100; /** Default for -persistmempool */ static const bool DEFAULT_PERSIST_MEMPOOL = true; /** Default for using fee filter */ @@ -127,7 +127,6 @@ extern bool g_parallel_script_checks; extern bool fRequireStandard; extern bool fCheckBlockIndex; extern bool fCheckpointsEnabled; -extern size_t nCoinCacheUsage; /** A fee rate smaller than this is considered zero fee (for relaying, mining and transaction creation) */ extern CFeeRate minRelayTxFee; /** If the tip is older than this (in seconds), the node is considered to be in initial block download. */ @@ -149,6 +148,8 @@ extern bool fHavePruned; extern bool fPruneMode; /** Number of MiB of block files that we're trying to stay below. */ extern uint64_t nPruneTarget; +/** Documentation for argument 'checklevel'. */ +extern const std::vector<std::string> CHECKLEVEL_DOC; /** Open a block file (blk?????.dat) */ FILE* OpenBlockFile(const FlatFilePos &pos, bool fReadOnly = false); @@ -159,11 +160,22 @@ void LoadExternalBlockFile(const CChainParams& chainparams, FILE* fileIn, FlatFi /** Ensures we have a genesis block in the block tree, possibly writing one to disk. */ bool LoadGenesisBlock(const CChainParams& chainparams); /** Unload database information */ -void UnloadBlockIndex(); +void UnloadBlockIndex(CTxMemPool* mempool); /** Run an instance of the script checking thread */ void ThreadScriptCheck(int worker_num); -/** Retrieve a transaction (from memory pool, or from disk, if possible) */ -bool GetTransaction(const uint256& hash, CTransactionRef& tx, const Consensus::Params& params, uint256& hashBlock, const CBlockIndex* const blockIndex = nullptr); +/** + * Return transaction from the block at block_index. + * If block_index is not provided, fall back to mempool. + * If mempool is not provided or the tx couldn't be found in mempool, fall back to g_txindex. + * + * @param[in] block_index The block to read from disk, or nullptr + * @param[in] mempool If block_index is not provided, look in the mempool, if provided + * @param[in] hash The txid + * @param[in] consensusParams The params + * @param[out] hashBlock The hash of block_index, if the tx was found via block_index + * @returns The tx if found, otherwise nullptr + */ +CTransactionRef GetTransaction(const CBlockIndex* const block_index, const CTxMemPool* const mempool, const uint256& hash, const Consensus::Params& consensusParams, uint256& hashBlock); /** * Find the best known block, and make it the tip of the block chain * @@ -519,7 +531,7 @@ public: //! Initialize the in-memory coins cache (to be done after the health of the on-disk database //! is verified). - void InitCoinsCache() EXCLUSIVE_LOCKS_REQUIRED(::cs_main); + void InitCoinsCache(size_t cache_size_bytes) EXCLUSIVE_LOCKS_REQUIRED(::cs_main); //! @returns whether or not the CoinsViews object has been fully initialized and we can //! safely flush this object to disk. @@ -568,6 +580,17 @@ public: //! Destructs all objects related to accessing the UTXO set. void ResetCoinsViews() { m_coins_views.reset(); } + //! The cache size of the on-disk coins view. + size_t m_coinsdb_cache_size_bytes{0}; + + //! The cache size of the in-memory coins view. + size_t m_coinstip_cache_size_bytes{0}; + + //! Resize the CoinsViews caches dynamically and flush state to disk. + //! @returns true unless an error occurred during the flush. + bool ResizeCoinsCaches(size_t coinstip_size, size_t coinsdb_size) + EXCLUSIVE_LOCKS_REQUIRED(::cs_main); + /** * Update the on-disk chain state. * The caches and indexes are flushed depending on the mode we're called with @@ -651,11 +674,11 @@ public: //! Dictates whether we need to flush the cache to disk or not. //! //! @return the state of the size of the coins cache. - CoinsCacheSizeState GetCoinsCacheSizeState(const CTxMemPool& tx_pool) + CoinsCacheSizeState GetCoinsCacheSizeState(const CTxMemPool* tx_pool) EXCLUSIVE_LOCKS_REQUIRED(::cs_main); CoinsCacheSizeState GetCoinsCacheSizeState( - const CTxMemPool& tx_pool, + const CTxMemPool* tx_pool, size_t max_coins_cache_size_bytes, size_t max_mempool_size_bytes) EXCLUSIVE_LOCKS_REQUIRED(::cs_main); @@ -784,6 +807,14 @@ public: //! chainstate to avoid duplicating block metadata. BlockManager m_blockman GUARDED_BY(::cs_main); + //! The total number of bytes available for us to use across all in-memory + //! coins caches. This will be split somehow across chainstates. + int64_t m_total_coinstip_cache{0}; + // + //! The total number of bytes available for us to use across all leveldb + //! coins databases. This will be split somehow across chainstates. + int64_t m_total_coinsdb_cache{0}; + //! Instantiate a new chainstate and assign it based upon whether it is //! from a snapshot. //! @@ -796,7 +827,8 @@ public: std::vector<CChainState*> GetAll(); //! The most-work chain. - CChain& ActiveChain() const; + CChainState& ActiveChainstate() const; + CChain& ActiveChain() const { return ActiveChainstate().m_chain; } int ActiveHeight() const { return ActiveChain().Height(); } CBlockIndex* ActiveTip() const { return ActiveChain().Tip(); } @@ -841,7 +873,7 @@ public: * validationinterface callback. * * @param[in] pblock The block we want to process. - * @param[in] fForceProcessing Process this block even if unrequested; used for non-network block sources and whitelisted peers. + * @param[in] fForceProcessing Process this block even if unrequested; used for non-network block sources. * @param[out] fNewBlock A boolean which is set to indicate if the block was first received via this call * @returns If the block was processed, independently of block validity */ @@ -871,20 +903,21 @@ public: //! Clear (deconstruct) chainstate data. void Reset(); + + //! Check to see if caches are out of balance and if so, call + //! ResizeCoinsCaches() as needed. + void MaybeRebalanceCaches() EXCLUSIVE_LOCKS_REQUIRED(::cs_main); }; /** DEPRECATED! Please use node.chainman instead. May only be used in validation.cpp internally */ extern ChainstateManager g_chainman GUARDED_BY(::cs_main); -/** @returns the most-work valid chainstate. */ +/** Please prefer the identical ChainstateManager::ActiveChainstate */ CChainState& ChainstateActive(); -/** @returns the most-work chain. */ +/** Please prefer the identical ChainstateManager::ActiveChain */ CChain& ChainActive(); -/** @returns the global block index map. */ -BlockMap& BlockIndex(); - /** Global variable that points to the active block tree (protected by cs_main) */ extern std::unique_ptr<CBlockTreeDB> pblocktree; diff --git a/src/validationinterface.cpp b/src/validationinterface.cpp index 9437f9c817..3dfbcc581c 100644 --- a/src/validationinterface.cpp +++ b/src/validationinterface.cpp @@ -199,22 +199,22 @@ void CMainSignals::UpdatedBlockTip(const CBlockIndex *pindexNew, const CBlockInd fInitialDownload); } -void CMainSignals::TransactionAddedToMempool(const CTransactionRef &ptx) { - auto event = [ptx, this] { - m_internals->Iterate([&](CValidationInterface& callbacks) { callbacks.TransactionAddedToMempool(ptx); }); +void CMainSignals::TransactionAddedToMempool(const CTransactionRef& tx) { + auto event = [tx, this] { + m_internals->Iterate([&](CValidationInterface& callbacks) { callbacks.TransactionAddedToMempool(tx); }); }; ENQUEUE_AND_LOG_EVENT(event, "%s: txid=%s wtxid=%s", __func__, - ptx->GetHash().ToString(), - ptx->GetWitnessHash().ToString()); + tx->GetHash().ToString(), + tx->GetWitnessHash().ToString()); } -void CMainSignals::TransactionRemovedFromMempool(const CTransactionRef &ptx) { - auto event = [ptx, this] { - m_internals->Iterate([&](CValidationInterface& callbacks) { callbacks.TransactionRemovedFromMempool(ptx); }); +void CMainSignals::TransactionRemovedFromMempool(const CTransactionRef& tx, MemPoolRemovalReason reason) { + auto event = [tx, reason, this] { + m_internals->Iterate([&](CValidationInterface& callbacks) { callbacks.TransactionRemovedFromMempool(tx, reason); }); }; ENQUEUE_AND_LOG_EVENT(event, "%s: txid=%s wtxid=%s", __func__, - ptx->GetHash().ToString(), - ptx->GetWitnessHash().ToString()); + tx->GetHash().ToString(), + tx->GetWitnessHash().ToString()); } void CMainSignals::BlockConnected(const std::shared_ptr<const CBlock> &pblock, const CBlockIndex *pindex) { diff --git a/src/validationinterface.h b/src/validationinterface.h index 9c23965bc1..e96f2883fc 100644 --- a/src/validationinterface.h +++ b/src/validationinterface.h @@ -21,6 +21,7 @@ class CConnman; class CValidationInterface; class uint256; class CScheduler; +enum class MemPoolRemovalReason; /** Register subscriber */ void RegisterValidationInterface(CValidationInterface* callbacks); @@ -96,7 +97,7 @@ protected: * * Called on a background thread. */ - virtual void TransactionAddedToMempool(const CTransactionRef &ptxn) {} + virtual void TransactionAddedToMempool(const CTransactionRef& tx) {} /** * Notifies listeners of a transaction leaving mempool. * @@ -129,7 +130,7 @@ protected: * * Called on a background thread. */ - virtual void TransactionRemovedFromMempool(const CTransactionRef &ptx) {} + virtual void TransactionRemovedFromMempool(const CTransactionRef& tx, MemPoolRemovalReason reason) {} /** * Notifies listeners of a block being connected. * Provides a vector of transactions evicted from the mempool as a result. @@ -196,8 +197,8 @@ public: void UpdatedBlockTip(const CBlockIndex *, const CBlockIndex *, bool fInitialDownload); - void TransactionAddedToMempool(const CTransactionRef &); - void TransactionRemovedFromMempool(const CTransactionRef &); + void TransactionAddedToMempool(const CTransactionRef&); + void TransactionRemovedFromMempool(const CTransactionRef&, MemPoolRemovalReason); void BlockConnected(const std::shared_ptr<const CBlock> &, const CBlockIndex *pindex); void BlockDisconnected(const std::shared_ptr<const CBlock> &, const CBlockIndex* pindex); void ChainStateFlushed(const CBlockLocator &); diff --git a/src/version.h b/src/version.h index d932b512d4..b5f379e1b8 100644 --- a/src/version.h +++ b/src/version.h @@ -9,20 +9,13 @@ * network protocol versioning */ -static const int PROTOCOL_VERSION = 70015; +static const int PROTOCOL_VERSION = 70016; //! initial proto version, to be increased after version/verack negotiation static const int INIT_PROTO_VERSION = 209; -//! In this version, 'getheaders' was introduced. -static const int GETHEADERS_VERSION = 31800; - //! disconnect from peers older than this proto version -static const int MIN_PEER_PROTO_VERSION = GETHEADERS_VERSION; - -//! nTime field added to CAddress, starting with this version; -//! if possible, avoid requesting addresses nodes older than this -static const int CADDR_TIME_VERSION = 31402; +static const int MIN_PEER_PROTO_VERSION = 31800; //! BIP 0031, pong message, is enabled for all versions AFTER this one static const int BIP0031_VERSION = 60000; @@ -42,4 +35,7 @@ static const int SHORT_IDS_BLOCKS_VERSION = 70014; //! not banning for invalid compact blocks starts with this version static const int INVALID_CB_NO_BAN_VERSION = 70015; +//! "wtxidrelay" command for wtxid-based relay starts with this version +static const int WTXID_RELAY_VERSION = 70016; + #endif // BITCOIN_VERSION_H diff --git a/src/wallet/bdb.cpp b/src/wallet/bdb.cpp new file mode 100644 index 0000000000..24eb2ee34c --- /dev/null +++ b/src/wallet/bdb.cpp @@ -0,0 +1,826 @@ +// Copyright (c) 2009-2010 Satoshi Nakamoto +// Copyright (c) 2009-2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include <wallet/bdb.h> +#include <wallet/db.h> + +#include <util/strencodings.h> +#include <util/translation.h> + +#include <stdint.h> + +#ifndef WIN32 +#include <sys/stat.h> +#endif + +namespace { + +//! Make sure database has a unique fileid within the environment. If it +//! doesn't, throw an error. BDB caches do not work properly when more than one +//! open database has the same fileid (values written to one database may show +//! up in reads to other databases). +//! +//! BerkeleyDB generates unique fileids by default +//! (https://docs.oracle.com/cd/E17275_01/html/programmer_reference/program_copy.html), +//! so bitcoin should never create different databases with the same fileid, but +//! this error can be triggered if users manually copy database files. +void CheckUniqueFileid(const BerkeleyEnvironment& env, const std::string& filename, Db& db, WalletDatabaseFileId& fileid) +{ + if (env.IsMock()) return; + + int ret = db.get_mpf()->get_fileid(fileid.value); + if (ret != 0) { + throw std::runtime_error(strprintf("BerkeleyDatabase: Can't open database %s (get_fileid failed with %d)", filename, ret)); + } + + for (const auto& item : env.m_fileids) { + if (fileid == item.second && &fileid != &item.second) { + throw std::runtime_error(strprintf("BerkeleyDatabase: Can't open database %s (duplicates fileid %s from %s)", filename, + HexStr(item.second.value), item.first)); + } + } +} + +RecursiveMutex cs_db; +std::map<std::string, std::weak_ptr<BerkeleyEnvironment>> g_dbenvs GUARDED_BY(cs_db); //!< Map from directory name to db environment. +} // namespace + +bool WalletDatabaseFileId::operator==(const WalletDatabaseFileId& rhs) const +{ + return memcmp(value, &rhs.value, sizeof(value)) == 0; +} + +bool IsBDBWalletLoaded(const fs::path& wallet_path) +{ + fs::path env_directory; + std::string database_filename; + SplitWalletPath(wallet_path, env_directory, database_filename); + LOCK(cs_db); + auto env = g_dbenvs.find(env_directory.string()); + if (env == g_dbenvs.end()) return false; + auto database = env->second.lock(); + return database && database->IsDatabaseLoaded(database_filename); +} + +/** + * @param[in] wallet_path Path to wallet directory. Or (for backwards compatibility only) a path to a berkeley btree data file inside a wallet directory. + * @param[out] database_filename Filename of berkeley btree data file inside the wallet directory. + * @return A shared pointer to the BerkeleyEnvironment object for the wallet directory, never empty because ~BerkeleyEnvironment + * erases the weak pointer from the g_dbenvs map. + * @post A new BerkeleyEnvironment weak pointer is inserted into g_dbenvs if the directory path key was not already in the map. + */ +std::shared_ptr<BerkeleyEnvironment> GetWalletEnv(const fs::path& wallet_path, std::string& database_filename) +{ + fs::path env_directory; + SplitWalletPath(wallet_path, env_directory, database_filename); + LOCK(cs_db); + auto inserted = g_dbenvs.emplace(env_directory.string(), std::weak_ptr<BerkeleyEnvironment>()); + if (inserted.second) { + auto env = std::make_shared<BerkeleyEnvironment>(env_directory.string()); + inserted.first->second = env; + return env; + } + return inserted.first->second.lock(); +} + +// +// BerkeleyBatch +// + +void BerkeleyEnvironment::Close() +{ + if (!fDbEnvInit) + return; + + fDbEnvInit = false; + + for (auto& db : m_databases) { + BerkeleyDatabase& database = db.second.get(); + assert(database.m_refcount <= 0); + if (database.m_db) { + database.m_db->close(0); + database.m_db.reset(); + } + } + + FILE* error_file = nullptr; + dbenv->get_errfile(&error_file); + + int ret = dbenv->close(0); + if (ret != 0) + LogPrintf("BerkeleyEnvironment::Close: Error %d closing database environment: %s\n", ret, DbEnv::strerror(ret)); + if (!fMockDb) + DbEnv((u_int32_t)0).remove(strPath.c_str(), 0); + + if (error_file) fclose(error_file); + + UnlockDirectory(strPath, ".walletlock"); +} + +void BerkeleyEnvironment::Reset() +{ + dbenv.reset(new DbEnv(DB_CXX_NO_EXCEPTIONS)); + fDbEnvInit = false; + fMockDb = false; +} + +BerkeleyEnvironment::BerkeleyEnvironment(const fs::path& dir_path) : strPath(dir_path.string()) +{ + Reset(); +} + +BerkeleyEnvironment::~BerkeleyEnvironment() +{ + LOCK(cs_db); + g_dbenvs.erase(strPath); + Close(); +} + +bool BerkeleyEnvironment::Open(bilingual_str& err) +{ + if (fDbEnvInit) { + return true; + } + + fs::path pathIn = strPath; + TryCreateDirectories(pathIn); + if (!LockDirectory(pathIn, ".walletlock")) { + LogPrintf("Cannot obtain a lock on wallet directory %s. Another instance of bitcoin may be using it.\n", strPath); + err = strprintf(_("Error initializing wallet database environment %s!"), Directory()); + return false; + } + + fs::path pathLogDir = pathIn / "database"; + TryCreateDirectories(pathLogDir); + fs::path pathErrorFile = pathIn / "db.log"; + LogPrintf("BerkeleyEnvironment::Open: LogDir=%s ErrorFile=%s\n", pathLogDir.string(), pathErrorFile.string()); + + unsigned int nEnvFlags = 0; + if (gArgs.GetBoolArg("-privdb", DEFAULT_WALLET_PRIVDB)) + nEnvFlags |= DB_PRIVATE; + + dbenv->set_lg_dir(pathLogDir.string().c_str()); + dbenv->set_cachesize(0, 0x100000, 1); // 1 MiB should be enough for just the wallet + dbenv->set_lg_bsize(0x10000); + dbenv->set_lg_max(1048576); + dbenv->set_lk_max_locks(40000); + dbenv->set_lk_max_objects(40000); + dbenv->set_errfile(fsbridge::fopen(pathErrorFile, "a")); /// debug + dbenv->set_flags(DB_AUTO_COMMIT, 1); + dbenv->set_flags(DB_TXN_WRITE_NOSYNC, 1); + dbenv->log_set_config(DB_LOG_AUTO_REMOVE, 1); + int ret = dbenv->open(strPath.c_str(), + DB_CREATE | + DB_INIT_LOCK | + DB_INIT_LOG | + DB_INIT_MPOOL | + DB_INIT_TXN | + DB_THREAD | + DB_RECOVER | + nEnvFlags, + S_IRUSR | S_IWUSR); + if (ret != 0) { + LogPrintf("BerkeleyEnvironment::Open: Error %d opening database environment: %s\n", ret, DbEnv::strerror(ret)); + int ret2 = dbenv->close(0); + if (ret2 != 0) { + LogPrintf("BerkeleyEnvironment::Open: Error %d closing failed database environment: %s\n", ret2, DbEnv::strerror(ret2)); + } + Reset(); + err = strprintf(_("Error initializing wallet database environment %s!"), Directory()); + if (ret == DB_RUNRECOVERY) { + err += Untranslated(" ") + _("This error could occur if this wallet was not shutdown cleanly and was last loaded using a build with a newer version of Berkeley DB. If so, please use the software that last loaded this wallet"); + } + return false; + } + + fDbEnvInit = true; + fMockDb = false; + return true; +} + +//! Construct an in-memory mock Berkeley environment for testing +BerkeleyEnvironment::BerkeleyEnvironment() +{ + Reset(); + + LogPrint(BCLog::WALLETDB, "BerkeleyEnvironment::MakeMock\n"); + + dbenv->set_cachesize(1, 0, 1); + dbenv->set_lg_bsize(10485760 * 4); + dbenv->set_lg_max(10485760); + dbenv->set_lk_max_locks(10000); + dbenv->set_lk_max_objects(10000); + dbenv->set_flags(DB_AUTO_COMMIT, 1); + dbenv->log_set_config(DB_LOG_IN_MEMORY, 1); + int ret = dbenv->open(nullptr, + DB_CREATE | + DB_INIT_LOCK | + DB_INIT_LOG | + DB_INIT_MPOOL | + DB_INIT_TXN | + DB_THREAD | + DB_PRIVATE, + S_IRUSR | S_IWUSR); + if (ret > 0) { + throw std::runtime_error(strprintf("BerkeleyEnvironment::MakeMock: Error %d opening database environment.", ret)); + } + + fDbEnvInit = true; + fMockDb = true; +} + +BerkeleyBatch::SafeDbt::SafeDbt() +{ + m_dbt.set_flags(DB_DBT_MALLOC); +} + +BerkeleyBatch::SafeDbt::SafeDbt(void* data, size_t size) + : m_dbt(data, size) +{ +} + +BerkeleyBatch::SafeDbt::~SafeDbt() +{ + if (m_dbt.get_data() != nullptr) { + // Clear memory, e.g. in case it was a private key + memory_cleanse(m_dbt.get_data(), m_dbt.get_size()); + // under DB_DBT_MALLOC, data is malloced by the Dbt, but must be + // freed by the caller. + // https://docs.oracle.com/cd/E17275_01/html/api_reference/C/dbt.html + if (m_dbt.get_flags() & DB_DBT_MALLOC) { + free(m_dbt.get_data()); + } + } +} + +const void* BerkeleyBatch::SafeDbt::get_data() const +{ + return m_dbt.get_data(); +} + +u_int32_t BerkeleyBatch::SafeDbt::get_size() const +{ + return m_dbt.get_size(); +} + +BerkeleyBatch::SafeDbt::operator Dbt*() +{ + return &m_dbt; +} + +bool BerkeleyDatabase::Verify(bilingual_str& errorStr) +{ + fs::path walletDir = env->Directory(); + fs::path file_path = walletDir / strFile; + + LogPrintf("Using BerkeleyDB version %s\n", BerkeleyDatabaseVersion()); + LogPrintf("Using wallet %s\n", file_path.string()); + + if (!env->Open(errorStr)) { + return false; + } + + if (fs::exists(file_path)) + { + assert(m_refcount == 0); + + Db db(env->dbenv.get(), 0); + int result = db.verify(strFile.c_str(), nullptr, nullptr, 0); + if (result != 0) { + errorStr = strprintf(_("%s corrupt. Try using the wallet tool bitcoin-wallet to salvage or restoring a backup."), file_path); + return false; + } + } + // also return true if files does not exists + return true; +} + +void BerkeleyEnvironment::CheckpointLSN(const std::string& strFile) +{ + dbenv->txn_checkpoint(0, 0, 0); + if (fMockDb) + return; + dbenv->lsn_reset(strFile.c_str(), 0); +} + +BerkeleyDatabase::~BerkeleyDatabase() +{ + if (env) { + LOCK(cs_db); + env->CloseDb(strFile); + assert(!m_db); + size_t erased = env->m_databases.erase(strFile); + assert(erased == 1); + env->m_fileids.erase(strFile); + } +} + +BerkeleyBatch::BerkeleyBatch(BerkeleyDatabase& database, const char* pszMode, bool fFlushOnCloseIn) : pdb(nullptr), activeTxn(nullptr), m_cursor(nullptr), m_database(database) +{ + database.AddRef(); + database.Open(pszMode); + fReadOnly = (!strchr(pszMode, '+') && !strchr(pszMode, 'w')); + fFlushOnClose = fFlushOnCloseIn; + env = database.env.get(); + pdb = database.m_db.get(); + strFile = database.strFile; + bool fCreate = strchr(pszMode, 'c') != nullptr; + if (fCreate && !Exists(std::string("version"))) { + bool fTmp = fReadOnly; + fReadOnly = false; + Write(std::string("version"), CLIENT_VERSION); + fReadOnly = fTmp; + } +} + +void BerkeleyDatabase::Open(const char* pszMode) +{ + bool fCreate = strchr(pszMode, 'c') != nullptr; + unsigned int nFlags = DB_THREAD; + if (fCreate) + nFlags |= DB_CREATE; + + { + LOCK(cs_db); + bilingual_str open_err; + if (!env->Open(open_err)) + throw std::runtime_error("BerkeleyDatabase: Failed to open database environment."); + + if (m_db == nullptr) { + int ret; + std::unique_ptr<Db> pdb_temp = MakeUnique<Db>(env->dbenv.get(), 0); + + bool fMockDb = env->IsMock(); + if (fMockDb) { + DbMpoolFile* mpf = pdb_temp->get_mpf(); + ret = mpf->set_flags(DB_MPOOL_NOFILE, 1); + if (ret != 0) { + throw std::runtime_error(strprintf("BerkeleyDatabase: Failed to configure for no temp file backing for database %s", strFile)); + } + } + + ret = pdb_temp->open(nullptr, // Txn pointer + fMockDb ? nullptr : strFile.c_str(), // Filename + fMockDb ? strFile.c_str() : "main", // Logical db name + DB_BTREE, // Database type + nFlags, // Flags + 0); + + if (ret != 0) { + throw std::runtime_error(strprintf("BerkeleyDatabase: Error %d, can't open database %s", ret, strFile)); + } + m_file_path = (env->Directory() / strFile).string(); + + // Call CheckUniqueFileid on the containing BDB environment to + // avoid BDB data consistency bugs that happen when different data + // files in the same environment have the same fileid. + CheckUniqueFileid(*env, strFile, *pdb_temp, this->env->m_fileids[strFile]); + + m_db.reset(pdb_temp.release()); + + } + } +} + +void BerkeleyBatch::Flush() +{ + if (activeTxn) + return; + + // Flush database activity from memory pool to disk log + unsigned int nMinutes = 0; + if (fReadOnly) + nMinutes = 1; + + if (env) { // env is nullptr for dummy databases (i.e. in tests). Don't actually flush if env is nullptr so we don't segfault + env->dbenv->txn_checkpoint(nMinutes ? gArgs.GetArg("-dblogsize", DEFAULT_WALLET_DBLOGSIZE) * 1024 : 0, nMinutes, 0); + } +} + +void BerkeleyDatabase::IncrementUpdateCounter() +{ + ++nUpdateCounter; +} + +BerkeleyBatch::~BerkeleyBatch() +{ + Close(); + m_database.RemoveRef(); +} + +void BerkeleyBatch::Close() +{ + if (!pdb) + return; + if (activeTxn) + activeTxn->abort(); + activeTxn = nullptr; + pdb = nullptr; + CloseCursor(); + + if (fFlushOnClose) + Flush(); +} + +void BerkeleyEnvironment::CloseDb(const std::string& strFile) +{ + { + LOCK(cs_db); + auto it = m_databases.find(strFile); + assert(it != m_databases.end()); + BerkeleyDatabase& database = it->second.get(); + if (database.m_db) { + // Close the database handle + database.m_db->close(0); + database.m_db.reset(); + } + } +} + +void BerkeleyEnvironment::ReloadDbEnv() +{ + // Make sure that no Db's are in use + AssertLockNotHeld(cs_db); + std::unique_lock<RecursiveMutex> lock(cs_db); + m_db_in_use.wait(lock, [this](){ + for (auto& db : m_databases) { + if (db.second.get().m_refcount > 0) return false; + } + return true; + }); + + std::vector<std::string> filenames; + for (auto it : m_databases) { + filenames.push_back(it.first); + } + // Close the individual Db's + for (const std::string& filename : filenames) { + CloseDb(filename); + } + // Reset the environment + Flush(true); // This will flush and close the environment + Reset(); + bilingual_str open_err; + Open(open_err); +} + +bool BerkeleyDatabase::Rewrite(const char* pszSkip) +{ + while (true) { + { + LOCK(cs_db); + if (m_refcount <= 0) { + // Flush log data to the dat file + env->CloseDb(strFile); + env->CheckpointLSN(strFile); + m_refcount = -1; + + bool fSuccess = true; + LogPrintf("BerkeleyBatch::Rewrite: Rewriting %s...\n", strFile); + std::string strFileRes = strFile + ".rewrite"; + { // surround usage of db with extra {} + BerkeleyBatch db(*this, "r"); + std::unique_ptr<Db> pdbCopy = MakeUnique<Db>(env->dbenv.get(), 0); + + int ret = pdbCopy->open(nullptr, // Txn pointer + strFileRes.c_str(), // Filename + "main", // Logical db name + DB_BTREE, // Database type + DB_CREATE, // Flags + 0); + if (ret > 0) { + LogPrintf("BerkeleyBatch::Rewrite: Can't create database file %s\n", strFileRes); + fSuccess = false; + } + + if (db.StartCursor()) { + while (fSuccess) { + CDataStream ssKey(SER_DISK, CLIENT_VERSION); + CDataStream ssValue(SER_DISK, CLIENT_VERSION); + bool complete; + bool ret1 = db.ReadAtCursor(ssKey, ssValue, complete); + if (complete) { + break; + } else if (!ret1) { + fSuccess = false; + break; + } + if (pszSkip && + strncmp(ssKey.data(), pszSkip, std::min(ssKey.size(), strlen(pszSkip))) == 0) + continue; + if (strncmp(ssKey.data(), "\x07version", 8) == 0) { + // Update version: + ssValue.clear(); + ssValue << CLIENT_VERSION; + } + Dbt datKey(ssKey.data(), ssKey.size()); + Dbt datValue(ssValue.data(), ssValue.size()); + int ret2 = pdbCopy->put(nullptr, &datKey, &datValue, DB_NOOVERWRITE); + if (ret2 > 0) + fSuccess = false; + } + db.CloseCursor(); + } + if (fSuccess) { + db.Close(); + env->CloseDb(strFile); + if (pdbCopy->close(0)) + fSuccess = false; + } else { + pdbCopy->close(0); + } + } + if (fSuccess) { + Db dbA(env->dbenv.get(), 0); + if (dbA.remove(strFile.c_str(), nullptr, 0)) + fSuccess = false; + Db dbB(env->dbenv.get(), 0); + if (dbB.rename(strFileRes.c_str(), nullptr, strFile.c_str(), 0)) + fSuccess = false; + } + if (!fSuccess) + LogPrintf("BerkeleyBatch::Rewrite: Failed to rewrite database file %s\n", strFileRes); + return fSuccess; + } + } + UninterruptibleSleep(std::chrono::milliseconds{100}); + } +} + + +void BerkeleyEnvironment::Flush(bool fShutdown) +{ + int64_t nStart = GetTimeMillis(); + // Flush log data to the actual data file on all files that are not in use + LogPrint(BCLog::WALLETDB, "BerkeleyEnvironment::Flush: [%s] Flush(%s)%s\n", strPath, fShutdown ? "true" : "false", fDbEnvInit ? "" : " database not started"); + if (!fDbEnvInit) + return; + { + LOCK(cs_db); + bool no_dbs_accessed = true; + for (auto& db_it : m_databases) { + std::string strFile = db_it.first; + int nRefCount = db_it.second.get().m_refcount; + if (nRefCount < 0) continue; + LogPrint(BCLog::WALLETDB, "BerkeleyEnvironment::Flush: Flushing %s (refcount = %d)...\n", strFile, nRefCount); + if (nRefCount == 0) { + // Move log data to the dat file + CloseDb(strFile); + LogPrint(BCLog::WALLETDB, "BerkeleyEnvironment::Flush: %s checkpoint\n", strFile); + dbenv->txn_checkpoint(0, 0, 0); + LogPrint(BCLog::WALLETDB, "BerkeleyEnvironment::Flush: %s detach\n", strFile); + if (!fMockDb) + dbenv->lsn_reset(strFile.c_str(), 0); + LogPrint(BCLog::WALLETDB, "BerkeleyEnvironment::Flush: %s closed\n", strFile); + nRefCount = -1; + } else { + no_dbs_accessed = false; + } + } + LogPrint(BCLog::WALLETDB, "BerkeleyEnvironment::Flush: Flush(%s)%s took %15dms\n", fShutdown ? "true" : "false", fDbEnvInit ? "" : " database not started", GetTimeMillis() - nStart); + if (fShutdown) { + char** listp; + if (no_dbs_accessed) { + dbenv->log_archive(&listp, DB_ARCH_REMOVE); + Close(); + if (!fMockDb) { + fs::remove_all(fs::path(strPath) / "database"); + } + } + } + } +} + +bool BerkeleyDatabase::PeriodicFlush() +{ + // Don't flush if we can't acquire the lock. + TRY_LOCK(cs_db, lockDb); + if (!lockDb) return false; + + // Don't flush if any databases are in use + for (auto& it : env->m_databases) { + if (it.second.get().m_refcount > 0) return false; + } + + // Don't flush if there haven't been any batch writes for this database. + if (m_refcount < 0) return false; + + LogPrint(BCLog::WALLETDB, "Flushing %s\n", strFile); + int64_t nStart = GetTimeMillis(); + + // Flush wallet file so it's self contained + env->CloseDb(strFile); + env->CheckpointLSN(strFile); + m_refcount = -1; + + LogPrint(BCLog::WALLETDB, "Flushed %s %dms\n", strFile, GetTimeMillis() - nStart); + + return true; +} + +bool BerkeleyDatabase::Backup(const std::string& strDest) const +{ + while (true) + { + { + LOCK(cs_db); + if (m_refcount <= 0) + { + // Flush log data to the dat file + env->CloseDb(strFile); + env->CheckpointLSN(strFile); + + // Copy wallet file + fs::path pathSrc = env->Directory() / strFile; + fs::path pathDest(strDest); + if (fs::is_directory(pathDest)) + pathDest /= strFile; + + try { + if (fs::equivalent(pathSrc, pathDest)) { + LogPrintf("cannot backup to wallet source file %s\n", pathDest.string()); + return false; + } + + fs::copy_file(pathSrc, pathDest, fs::copy_option::overwrite_if_exists); + LogPrintf("copied %s to %s\n", strFile, pathDest.string()); + return true; + } catch (const fs::filesystem_error& e) { + LogPrintf("error copying %s to %s - %s\n", strFile, pathDest.string(), fsbridge::get_filesystem_error_message(e)); + return false; + } + } + } + UninterruptibleSleep(std::chrono::milliseconds{100}); + } +} + +void BerkeleyDatabase::Flush() +{ + env->Flush(false); +} + +void BerkeleyDatabase::Close() +{ + env->Flush(true); +} + +void BerkeleyDatabase::ReloadDbEnv() +{ + env->ReloadDbEnv(); +} + +bool BerkeleyBatch::StartCursor() +{ + assert(!m_cursor); + if (!pdb) + return false; + int ret = pdb->cursor(nullptr, &m_cursor, 0); + return ret == 0; +} + +bool BerkeleyBatch::ReadAtCursor(CDataStream& ssKey, CDataStream& ssValue, bool& complete) +{ + complete = false; + if (m_cursor == nullptr) return false; + // Read at cursor + SafeDbt datKey; + SafeDbt datValue; + int ret = m_cursor->get(datKey, datValue, DB_NEXT); + if (ret == DB_NOTFOUND) { + complete = true; + } + if (ret != 0) + return false; + else if (datKey.get_data() == nullptr || datValue.get_data() == nullptr) + return false; + + // Convert to streams + ssKey.SetType(SER_DISK); + ssKey.clear(); + ssKey.write((char*)datKey.get_data(), datKey.get_size()); + ssValue.SetType(SER_DISK); + ssValue.clear(); + ssValue.write((char*)datValue.get_data(), datValue.get_size()); + return true; +} + +void BerkeleyBatch::CloseCursor() +{ + if (!m_cursor) return; + m_cursor->close(); + m_cursor = nullptr; +} + +bool BerkeleyBatch::TxnBegin() +{ + if (!pdb || activeTxn) + return false; + DbTxn* ptxn = env->TxnBegin(); + if (!ptxn) + return false; + activeTxn = ptxn; + return true; +} + +bool BerkeleyBatch::TxnCommit() +{ + if (!pdb || !activeTxn) + return false; + int ret = activeTxn->commit(0); + activeTxn = nullptr; + return (ret == 0); +} + +bool BerkeleyBatch::TxnAbort() +{ + if (!pdb || !activeTxn) + return false; + int ret = activeTxn->abort(); + activeTxn = nullptr; + return (ret == 0); +} + +std::string BerkeleyDatabaseVersion() +{ + return DbEnv::version(nullptr, nullptr, nullptr); +} + +bool BerkeleyBatch::ReadKey(CDataStream&& key, CDataStream& value) +{ + if (!pdb) + return false; + + SafeDbt datKey(key.data(), key.size()); + + SafeDbt datValue; + int ret = pdb->get(activeTxn, datKey, datValue, 0); + if (ret == 0 && datValue.get_data() != nullptr) { + value.write((char*)datValue.get_data(), datValue.get_size()); + return true; + } + return false; +} + +bool BerkeleyBatch::WriteKey(CDataStream&& key, CDataStream&& value, bool overwrite) +{ + if (!pdb) + return false; + if (fReadOnly) + assert(!"Write called on database in read-only mode"); + + SafeDbt datKey(key.data(), key.size()); + + SafeDbt datValue(value.data(), value.size()); + + int ret = pdb->put(activeTxn, datKey, datValue, (overwrite ? 0 : DB_NOOVERWRITE)); + return (ret == 0); +} + +bool BerkeleyBatch::EraseKey(CDataStream&& key) +{ + if (!pdb) + return false; + if (fReadOnly) + assert(!"Erase called on database in read-only mode"); + + SafeDbt datKey(key.data(), key.size()); + + int ret = pdb->del(activeTxn, datKey, 0); + return (ret == 0 || ret == DB_NOTFOUND); +} + +bool BerkeleyBatch::HasKey(CDataStream&& key) +{ + if (!pdb) + return false; + + SafeDbt datKey(key.data(), key.size()); + + int ret = pdb->exists(activeTxn, datKey, 0); + return ret == 0; +} + +void BerkeleyDatabase::AddRef() +{ + LOCK(cs_db); + if (m_refcount < 0) { + m_refcount = 1; + } else { + m_refcount++; + } +} + +void BerkeleyDatabase::RemoveRef() +{ + LOCK(cs_db); + m_refcount--; + if (env) env->m_db_in_use.notify_all(); +} + +std::unique_ptr<DatabaseBatch> BerkeleyDatabase::MakeBatch(const char* mode, bool flush_on_close) +{ + return MakeUnique<BerkeleyBatch>(*this, mode, flush_on_close); +} diff --git a/src/wallet/bdb.h b/src/wallet/bdb.h new file mode 100644 index 0000000000..75546924e8 --- /dev/null +++ b/src/wallet/bdb.h @@ -0,0 +1,227 @@ +// Copyright (c) 2009-2010 Satoshi Nakamoto +// Copyright (c) 2009-2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#ifndef BITCOIN_WALLET_BDB_H +#define BITCOIN_WALLET_BDB_H + +#include <clientversion.h> +#include <fs.h> +#include <serialize.h> +#include <streams.h> +#include <util/system.h> +#include <wallet/db.h> + +#include <atomic> +#include <map> +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wsuggest-override" +#endif +#include <db_cxx.h> +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic pop +#endif + +struct bilingual_str; + +static const unsigned int DEFAULT_WALLET_DBLOGSIZE = 100; +static const bool DEFAULT_WALLET_PRIVDB = true; + +struct WalletDatabaseFileId { + u_int8_t value[DB_FILE_ID_LEN]; + bool operator==(const WalletDatabaseFileId& rhs) const; +}; + +class BerkeleyDatabase; + +class BerkeleyEnvironment +{ +private: + bool fDbEnvInit; + bool fMockDb; + // Don't change into fs::path, as that can result in + // shutdown problems/crashes caused by a static initialized internal pointer. + std::string strPath; + +public: + std::unique_ptr<DbEnv> dbenv; + std::map<std::string, std::reference_wrapper<BerkeleyDatabase>> m_databases; + std::unordered_map<std::string, WalletDatabaseFileId> m_fileids; + std::condition_variable_any m_db_in_use; + + BerkeleyEnvironment(const fs::path& env_directory); + BerkeleyEnvironment(); + ~BerkeleyEnvironment(); + void Reset(); + + bool IsMock() const { return fMockDb; } + bool IsInitialized() const { return fDbEnvInit; } + bool IsDatabaseLoaded(const std::string& db_filename) const { return m_databases.find(db_filename) != m_databases.end(); } + fs::path Directory() const { return strPath; } + + bool Open(bilingual_str& error); + void Close(); + void Flush(bool fShutdown); + void CheckpointLSN(const std::string& strFile); + + void CloseDb(const std::string& strFile); + void ReloadDbEnv(); + + DbTxn* TxnBegin(int flags = DB_TXN_WRITE_NOSYNC) + { + DbTxn* ptxn = nullptr; + int ret = dbenv->txn_begin(nullptr, &ptxn, flags); + if (!ptxn || ret != 0) + return nullptr; + return ptxn; + } +}; + +/** Get BerkeleyEnvironment and database filename given a wallet path. */ +std::shared_ptr<BerkeleyEnvironment> GetWalletEnv(const fs::path& wallet_path, std::string& database_filename); + +/** Return whether a BDB wallet database is currently loaded. */ +bool IsBDBWalletLoaded(const fs::path& wallet_path); + +class BerkeleyBatch; + +/** An instance of this class represents one database. + * For BerkeleyDB this is just a (env, strFile) tuple. + **/ +class BerkeleyDatabase : public WalletDatabase +{ +public: + BerkeleyDatabase() = delete; + + /** Create DB handle to real database */ + BerkeleyDatabase(std::shared_ptr<BerkeleyEnvironment> env, std::string filename) : + WalletDatabase(), env(std::move(env)), strFile(std::move(filename)) + { + auto inserted = this->env->m_databases.emplace(strFile, std::ref(*this)); + assert(inserted.second); + } + + ~BerkeleyDatabase() override; + + /** Open the database if it is not already opened. + * Dummy function, doesn't do anything right now, but is needed for class abstraction */ + void Open(const char* mode) override; + + /** Rewrite the entire database on disk, with the exception of key pszSkip if non-zero + */ + bool Rewrite(const char* pszSkip=nullptr) override; + + /** Indicate the a new database user has began using the database. */ + void AddRef() override; + /** Indicate that database user has stopped using the database and that it could be flushed or closed. */ + void RemoveRef() override; + + /** Back up the entire database to a file. + */ + bool Backup(const std::string& strDest) const override; + + /** Make sure all changes are flushed to database file. + */ + void Flush() override; + /** Flush to the database file and close the database. + * Also close the environment if no other databases are open in it. + */ + void Close() override; + /* flush the wallet passively (TRY_LOCK) + ideal to be called periodically */ + bool PeriodicFlush() override; + + void IncrementUpdateCounter() override; + + void ReloadDbEnv() override; + + /** Verifies the environment and database file */ + bool Verify(bilingual_str& error) override; + + /** + * Pointer to shared database environment. + * + * Normally there is only one BerkeleyDatabase object per + * BerkeleyEnvivonment, but in the special, backwards compatible case where + * multiple wallet BDB data files are loaded from the same directory, this + * will point to a shared instance that gets freed when the last data file + * is closed. + */ + std::shared_ptr<BerkeleyEnvironment> env; + + /** Database pointer. This is initialized lazily and reset during flushes, so it can be null. */ + std::unique_ptr<Db> m_db; + + std::string strFile; + + /** Make a BerkeleyBatch connected to this database */ + std::unique_ptr<DatabaseBatch> MakeBatch(const char* mode = "r+", bool flush_on_close = true) override; +}; + +/** RAII class that provides access to a Berkeley database */ +class BerkeleyBatch : public DatabaseBatch +{ + /** RAII class that automatically cleanses its data on destruction */ + class SafeDbt final + { + Dbt m_dbt; + + public: + // construct Dbt with internally-managed data + SafeDbt(); + // construct Dbt with provided data + SafeDbt(void* data, size_t size); + ~SafeDbt(); + + // delegate to Dbt + const void* get_data() const; + u_int32_t get_size() const; + + // conversion operator to access the underlying Dbt + operator Dbt*(); + }; + +private: + bool ReadKey(CDataStream&& key, CDataStream& value) override; + bool WriteKey(CDataStream&& key, CDataStream&& value, bool overwrite = true) override; + bool EraseKey(CDataStream&& key) override; + bool HasKey(CDataStream&& key) override; + +protected: + Db* pdb; + std::string strFile; + DbTxn* activeTxn; + Dbc* m_cursor; + bool fReadOnly; + bool fFlushOnClose; + BerkeleyEnvironment *env; + BerkeleyDatabase& m_database; + +public: + explicit BerkeleyBatch(BerkeleyDatabase& database, const char* pszMode = "r+", bool fFlushOnCloseIn=true); + ~BerkeleyBatch() override; + + BerkeleyBatch(const BerkeleyBatch&) = delete; + BerkeleyBatch& operator=(const BerkeleyBatch&) = delete; + + void Flush() override; + void Close() override; + + bool StartCursor() override; + bool ReadAtCursor(CDataStream& ssKey, CDataStream& ssValue, bool& complete) override; + void CloseCursor() override; + bool TxnBegin() override; + bool TxnCommit() override; + bool TxnAbort() override; +}; + +std::string BerkeleyDatabaseVersion(); + +#endif // BITCOIN_WALLET_BDB_H diff --git a/src/wallet/coincontrol.cpp b/src/wallet/coincontrol.cpp index c83e598825..720877ead0 100644 --- a/src/wallet/coincontrol.cpp +++ b/src/wallet/coincontrol.cpp @@ -10,6 +10,7 @@ void CCoinControl::SetNull() { destChange = CNoDestination(); m_change_type.reset(); + m_add_inputs = true; fAllowOtherInputs = false; fAllowWatchOnly = false; m_avoid_partial_spends = gArgs.GetBoolArg("-avoidpartialspends", DEFAULT_AVOIDPARTIALSPENDS); @@ -23,4 +24,3 @@ void CCoinControl::SetNull() m_min_depth = DEFAULT_MIN_DEPTH; m_max_depth = DEFAULT_MAX_DEPTH; } - diff --git a/src/wallet/coincontrol.h b/src/wallet/coincontrol.h index 2893d0ab3d..c499b0ff25 100644 --- a/src/wallet/coincontrol.h +++ b/src/wallet/coincontrol.h @@ -26,6 +26,8 @@ public: CTxDestination destChange; //! Override the default change type if set, ignored if destChange is set Optional<OutputType> m_change_type; + //! If false, only selected inputs are used + bool m_add_inputs; //! If false, allows unselected inputs, but requires all selected inputs be used bool fAllowOtherInputs; //! Includes watch only addresses which are solvable diff --git a/src/wallet/context.cpp b/src/wallet/context.cpp new file mode 100644 index 0000000000..09b2f30467 --- /dev/null +++ b/src/wallet/context.cpp @@ -0,0 +1,8 @@ +// Copyright (c) 2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include <wallet/context.h> + +WalletContext::WalletContext() {} +WalletContext::~WalletContext() {} diff --git a/src/wallet/context.h b/src/wallet/context.h new file mode 100644 index 0000000000..a83591154f --- /dev/null +++ b/src/wallet/context.h @@ -0,0 +1,34 @@ +// Copyright (c) 2020 The Bitcoin Core developers +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#ifndef BITCOIN_WALLET_CONTEXT_H +#define BITCOIN_WALLET_CONTEXT_H + +class ArgsManager; +namespace interfaces { +class Chain; +} // namespace interfaces + +//! WalletContext struct containing references to state shared between CWallet +//! instances, like the reference to the chain interface, and the list of opened +//! wallets. +//! +//! Future shared state can be added here as an alternative to adding global +//! variables. +//! +//! The struct isn't intended to have any member functions. It should just be a +//! collection of state pointers that doesn't pull in dependencies or implement +//! behavior. +struct WalletContext { + interfaces::Chain* chain{nullptr}; + ArgsManager* args{nullptr}; + + //! Declare default constructor and destructor that are not inline, so code + //! instantiating the WalletContext struct doesn't need to #include class + //! definitions for smart pointer and container members. + WalletContext(); + ~WalletContext(); +}; + +#endif // BITCOIN_WALLET_CONTEXT_H diff --git a/src/wallet/db.cpp b/src/wallet/db.cpp index 4ed28b0623..1eb82a03c7 100644 --- a/src/wallet/db.cpp +++ b/src/wallet/db.cpp @@ -3,57 +3,12 @@ // Distributed under the MIT software license, see the accompanying // file COPYING or http://www.opensource.org/licenses/mit-license.php. +#include <fs.h> #include <wallet/db.h> -#include <util/strencodings.h> -#include <util/translation.h> +#include <string> -#include <stdint.h> - -#ifndef WIN32 -#include <sys/stat.h> -#endif - -#include <boost/thread.hpp> - -namespace { - -//! Make sure database has a unique fileid within the environment. If it -//! doesn't, throw an error. BDB caches do not work properly when more than one -//! open database has the same fileid (values written to one database may show -//! up in reads to other databases). -//! -//! BerkeleyDB generates unique fileids by default -//! (https://docs.oracle.com/cd/E17275_01/html/programmer_reference/program_copy.html), -//! so bitcoin should never create different databases with the same fileid, but -//! this error can be triggered if users manually copy database files. -void CheckUniqueFileid(const BerkeleyEnvironment& env, const std::string& filename, Db& db, WalletDatabaseFileId& fileid) -{ - if (env.IsMock()) return; - - int ret = db.get_mpf()->get_fileid(fileid.value); - if (ret != 0) { - throw std::runtime_error(strprintf("BerkeleyBatch: Can't open database %s (get_fileid failed with %d)", filename, ret)); - } - - for (const auto& item : env.m_fileids) { - if (fileid == item.second && &fileid != &item.second) { - throw std::runtime_error(strprintf("BerkeleyBatch: Can't open database %s (duplicates fileid %s from %s)", filename, - HexStr(std::begin(item.second.value), std::end(item.second.value)), item.first)); - } - } -} - -RecursiveMutex cs_db; -std::map<std::string, std::weak_ptr<BerkeleyEnvironment>> g_dbenvs GUARDED_BY(cs_db); //!< Map from directory name to db environment. -} // namespace - -bool WalletDatabaseFileId::operator==(const WalletDatabaseFileId& rhs) const -{ - return memcmp(value, &rhs.value, sizeof(value)) == 0; -} - -static void SplitWalletPath(const fs::path& wallet_path, fs::path& env_directory, std::string& database_filename) +void SplitWalletPath(const fs::path& wallet_path, fs::path& env_directory, std::string& database_filename) { if (fs::is_regular_file(wallet_path)) { // Special case for backwards compatibility: if wallet path points to an @@ -69,18 +24,6 @@ static void SplitWalletPath(const fs::path& wallet_path, fs::path& env_directory } } -bool IsWalletLoaded(const fs::path& wallet_path) -{ - fs::path env_directory; - std::string database_filename; - SplitWalletPath(wallet_path, env_directory, database_filename); - LOCK(cs_db); - auto env = g_dbenvs.find(env_directory.string()); - if (env == g_dbenvs.end()) return false; - auto database = env->second.lock(); - return database && database->IsDatabaseLoaded(database_filename); -} - fs::path WalletDataFilePath(const fs::path& wallet_path) { fs::path env_directory; @@ -88,683 +31,3 @@ fs::path WalletDataFilePath(const fs::path& wallet_path) SplitWalletPath(wallet_path, env_directory, database_filename); return env_directory / database_filename; } - -/** - * @param[in] wallet_path Path to wallet directory. Or (for backwards compatibility only) a path to a berkeley btree data file inside a wallet directory. - * @param[out] database_filename Filename of berkeley btree data file inside the wallet directory. - * @return A shared pointer to the BerkeleyEnvironment object for the wallet directory, never empty because ~BerkeleyEnvironment - * erases the weak pointer from the g_dbenvs map. - * @post A new BerkeleyEnvironment weak pointer is inserted into g_dbenvs if the directory path key was not already in the map. - */ -std::shared_ptr<BerkeleyEnvironment> GetWalletEnv(const fs::path& wallet_path, std::string& database_filename) -{ - fs::path env_directory; - SplitWalletPath(wallet_path, env_directory, database_filename); - LOCK(cs_db); - auto inserted = g_dbenvs.emplace(env_directory.string(), std::weak_ptr<BerkeleyEnvironment>()); - if (inserted.second) { - auto env = std::make_shared<BerkeleyEnvironment>(env_directory.string()); - inserted.first->second = env; - return env; - } - return inserted.first->second.lock(); -} - -// -// BerkeleyBatch -// - -void BerkeleyEnvironment::Close() -{ - if (!fDbEnvInit) - return; - - fDbEnvInit = false; - - for (auto& db : m_databases) { - auto count = mapFileUseCount.find(db.first); - assert(count == mapFileUseCount.end() || count->second == 0); - BerkeleyDatabase& database = db.second.get(); - if (database.m_db) { - database.m_db->close(0); - database.m_db.reset(); - } - } - - FILE* error_file = nullptr; - dbenv->get_errfile(&error_file); - - int ret = dbenv->close(0); - if (ret != 0) - LogPrintf("BerkeleyEnvironment::Close: Error %d closing database environment: %s\n", ret, DbEnv::strerror(ret)); - if (!fMockDb) - DbEnv((u_int32_t)0).remove(strPath.c_str(), 0); - - if (error_file) fclose(error_file); - - UnlockDirectory(strPath, ".walletlock"); -} - -void BerkeleyEnvironment::Reset() -{ - dbenv.reset(new DbEnv(DB_CXX_NO_EXCEPTIONS)); - fDbEnvInit = false; - fMockDb = false; -} - -BerkeleyEnvironment::BerkeleyEnvironment(const fs::path& dir_path) : strPath(dir_path.string()) -{ - Reset(); -} - -BerkeleyEnvironment::~BerkeleyEnvironment() -{ - LOCK(cs_db); - g_dbenvs.erase(strPath); - Close(); -} - -bool BerkeleyEnvironment::Open(bool retry) -{ - if (fDbEnvInit) { - return true; - } - - fs::path pathIn = strPath; - TryCreateDirectories(pathIn); - if (!LockDirectory(pathIn, ".walletlock")) { - LogPrintf("Cannot obtain a lock on wallet directory %s. Another instance of bitcoin may be using it.\n", strPath); - return false; - } - - fs::path pathLogDir = pathIn / "database"; - TryCreateDirectories(pathLogDir); - fs::path pathErrorFile = pathIn / "db.log"; - LogPrintf("BerkeleyEnvironment::Open: LogDir=%s ErrorFile=%s\n", pathLogDir.string(), pathErrorFile.string()); - - unsigned int nEnvFlags = 0; - if (gArgs.GetBoolArg("-privdb", DEFAULT_WALLET_PRIVDB)) - nEnvFlags |= DB_PRIVATE; - - dbenv->set_lg_dir(pathLogDir.string().c_str()); - dbenv->set_cachesize(0, 0x100000, 1); // 1 MiB should be enough for just the wallet - dbenv->set_lg_bsize(0x10000); - dbenv->set_lg_max(1048576); - dbenv->set_lk_max_locks(40000); - dbenv->set_lk_max_objects(40000); - dbenv->set_errfile(fsbridge::fopen(pathErrorFile, "a")); /// debug - dbenv->set_flags(DB_AUTO_COMMIT, 1); - dbenv->set_flags(DB_TXN_WRITE_NOSYNC, 1); - dbenv->log_set_config(DB_LOG_AUTO_REMOVE, 1); - int ret = dbenv->open(strPath.c_str(), - DB_CREATE | - DB_INIT_LOCK | - DB_INIT_LOG | - DB_INIT_MPOOL | - DB_INIT_TXN | - DB_THREAD | - DB_RECOVER | - nEnvFlags, - S_IRUSR | S_IWUSR); - if (ret != 0) { - LogPrintf("BerkeleyEnvironment::Open: Error %d opening database environment: %s\n", ret, DbEnv::strerror(ret)); - int ret2 = dbenv->close(0); - if (ret2 != 0) { - LogPrintf("BerkeleyEnvironment::Open: Error %d closing failed database environment: %s\n", ret2, DbEnv::strerror(ret2)); - } - Reset(); - if (retry) { - // try moving the database env out of the way - fs::path pathDatabaseBak = pathIn / strprintf("database.%d.bak", GetTime()); - try { - fs::rename(pathLogDir, pathDatabaseBak); - LogPrintf("Moved old %s to %s. Retrying.\n", pathLogDir.string(), pathDatabaseBak.string()); - } catch (const fs::filesystem_error&) { - // failure is ok (well, not really, but it's not worse than what we started with) - } - // try opening it again one more time - if (!Open(false /* retry */)) { - // if it still fails, it probably means we can't even create the database env - return false; - } - } else { - return false; - } - } - - fDbEnvInit = true; - fMockDb = false; - return true; -} - -//! Construct an in-memory mock Berkeley environment for testing -BerkeleyEnvironment::BerkeleyEnvironment() -{ - Reset(); - - LogPrint(BCLog::WALLETDB, "BerkeleyEnvironment::MakeMock\n"); - - dbenv->set_cachesize(1, 0, 1); - dbenv->set_lg_bsize(10485760 * 4); - dbenv->set_lg_max(10485760); - dbenv->set_lk_max_locks(10000); - dbenv->set_lk_max_objects(10000); - dbenv->set_flags(DB_AUTO_COMMIT, 1); - dbenv->log_set_config(DB_LOG_IN_MEMORY, 1); - int ret = dbenv->open(nullptr, - DB_CREATE | - DB_INIT_LOCK | - DB_INIT_LOG | - DB_INIT_MPOOL | - DB_INIT_TXN | - DB_THREAD | - DB_PRIVATE, - S_IRUSR | S_IWUSR); - if (ret > 0) { - throw std::runtime_error(strprintf("BerkeleyEnvironment::MakeMock: Error %d opening database environment.", ret)); - } - - fDbEnvInit = true; - fMockDb = true; -} - -bool BerkeleyEnvironment::Verify(const std::string& strFile) -{ - LOCK(cs_db); - assert(mapFileUseCount.count(strFile) == 0); - - Db db(dbenv.get(), 0); - int result = db.verify(strFile.c_str(), nullptr, nullptr, 0); - return result == 0; -} - -BerkeleyBatch::SafeDbt::SafeDbt() -{ - m_dbt.set_flags(DB_DBT_MALLOC); -} - -BerkeleyBatch::SafeDbt::SafeDbt(void* data, size_t size) - : m_dbt(data, size) -{ -} - -BerkeleyBatch::SafeDbt::~SafeDbt() -{ - if (m_dbt.get_data() != nullptr) { - // Clear memory, e.g. in case it was a private key - memory_cleanse(m_dbt.get_data(), m_dbt.get_size()); - // under DB_DBT_MALLOC, data is malloced by the Dbt, but must be - // freed by the caller. - // https://docs.oracle.com/cd/E17275_01/html/api_reference/C/dbt.html - if (m_dbt.get_flags() & DB_DBT_MALLOC) { - free(m_dbt.get_data()); - } - } -} - -const void* BerkeleyBatch::SafeDbt::get_data() const -{ - return m_dbt.get_data(); -} - -u_int32_t BerkeleyBatch::SafeDbt::get_size() const -{ - return m_dbt.get_size(); -} - -BerkeleyBatch::SafeDbt::operator Dbt*() -{ - return &m_dbt; -} - -bool BerkeleyBatch::VerifyEnvironment(const fs::path& file_path, bilingual_str& errorStr) -{ - std::string walletFile; - std::shared_ptr<BerkeleyEnvironment> env = GetWalletEnv(file_path, walletFile); - fs::path walletDir = env->Directory(); - - LogPrintf("Using BerkeleyDB version %s\n", BerkeleyDatabaseVersion()); - LogPrintf("Using wallet %s\n", file_path.string()); - - if (!env->Open(true /* retry */)) { - errorStr = strprintf(_("Error initializing wallet database environment %s!"), walletDir); - return false; - } - - return true; -} - -bool BerkeleyBatch::VerifyDatabaseFile(const fs::path& file_path, bilingual_str& errorStr) -{ - std::string walletFile; - std::shared_ptr<BerkeleyEnvironment> env = GetWalletEnv(file_path, walletFile); - fs::path walletDir = env->Directory(); - - if (fs::exists(walletDir / walletFile)) - { - if (!env->Verify(walletFile)) { - errorStr = strprintf(_("%s corrupt. Try using the wallet tool bitcoin-wallet to salvage or restoring a backup."), walletFile); - return false; - } - } - // also return true if files does not exists - return true; -} - -void BerkeleyEnvironment::CheckpointLSN(const std::string& strFile) -{ - dbenv->txn_checkpoint(0, 0, 0); - if (fMockDb) - return; - dbenv->lsn_reset(strFile.c_str(), 0); -} - - -BerkeleyBatch::BerkeleyBatch(BerkeleyDatabase& database, const char* pszMode, bool fFlushOnCloseIn) : pdb(nullptr), activeTxn(nullptr) -{ - fReadOnly = (!strchr(pszMode, '+') && !strchr(pszMode, 'w')); - fFlushOnClose = fFlushOnCloseIn; - env = database.env.get(); - if (database.IsDummy()) { - return; - } - const std::string &strFilename = database.strFile; - - bool fCreate = strchr(pszMode, 'c') != nullptr; - unsigned int nFlags = DB_THREAD; - if (fCreate) - nFlags |= DB_CREATE; - - { - LOCK(cs_db); - if (!env->Open(false /* retry */)) - throw std::runtime_error("BerkeleyBatch: Failed to open database environment."); - - pdb = database.m_db.get(); - if (pdb == nullptr) { - int ret; - std::unique_ptr<Db> pdb_temp = MakeUnique<Db>(env->dbenv.get(), 0); - - bool fMockDb = env->IsMock(); - if (fMockDb) { - DbMpoolFile* mpf = pdb_temp->get_mpf(); - ret = mpf->set_flags(DB_MPOOL_NOFILE, 1); - if (ret != 0) { - throw std::runtime_error(strprintf("BerkeleyBatch: Failed to configure for no temp file backing for database %s", strFilename)); - } - } - - ret = pdb_temp->open(nullptr, // Txn pointer - fMockDb ? nullptr : strFilename.c_str(), // Filename - fMockDb ? strFilename.c_str() : "main", // Logical db name - DB_BTREE, // Database type - nFlags, // Flags - 0); - - if (ret != 0) { - throw std::runtime_error(strprintf("BerkeleyBatch: Error %d, can't open database %s", ret, strFilename)); - } - - // Call CheckUniqueFileid on the containing BDB environment to - // avoid BDB data consistency bugs that happen when different data - // files in the same environment have the same fileid. - // - // Also call CheckUniqueFileid on all the other g_dbenvs to prevent - // bitcoin from opening the same data file through another - // environment when the file is referenced through equivalent but - // not obviously identical symlinked or hard linked or bind mounted - // paths. In the future a more relaxed check for equal inode and - // device ids could be done instead, which would allow opening - // different backup copies of a wallet at the same time. Maybe even - // more ideally, an exclusive lock for accessing the database could - // be implemented, so no equality checks are needed at all. (Newer - // versions of BDB have an set_lk_exclusive method for this - // purpose, but the older version we use does not.) - for (const auto& env : g_dbenvs) { - CheckUniqueFileid(*env.second.lock().get(), strFilename, *pdb_temp, this->env->m_fileids[strFilename]); - } - - pdb = pdb_temp.release(); - database.m_db.reset(pdb); - - if (fCreate && !Exists(std::string("version"))) { - bool fTmp = fReadOnly; - fReadOnly = false; - Write(std::string("version"), CLIENT_VERSION); - fReadOnly = fTmp; - } - } - ++env->mapFileUseCount[strFilename]; - strFile = strFilename; - } -} - -void BerkeleyBatch::Flush() -{ - if (activeTxn) - return; - - // Flush database activity from memory pool to disk log - unsigned int nMinutes = 0; - if (fReadOnly) - nMinutes = 1; - - if (env) { // env is nullptr for dummy databases (i.e. in tests). Don't actually flush if env is nullptr so we don't segfault - env->dbenv->txn_checkpoint(nMinutes ? gArgs.GetArg("-dblogsize", DEFAULT_WALLET_DBLOGSIZE) * 1024 : 0, nMinutes, 0); - } -} - -void BerkeleyDatabase::IncrementUpdateCounter() -{ - ++nUpdateCounter; -} - -void BerkeleyBatch::Close() -{ - if (!pdb) - return; - if (activeTxn) - activeTxn->abort(); - activeTxn = nullptr; - pdb = nullptr; - - if (fFlushOnClose) - Flush(); - - { - LOCK(cs_db); - --env->mapFileUseCount[strFile]; - } - env->m_db_in_use.notify_all(); -} - -void BerkeleyEnvironment::CloseDb(const std::string& strFile) -{ - { - LOCK(cs_db); - auto it = m_databases.find(strFile); - assert(it != m_databases.end()); - BerkeleyDatabase& database = it->second.get(); - if (database.m_db) { - // Close the database handle - database.m_db->close(0); - database.m_db.reset(); - } - } -} - -void BerkeleyEnvironment::ReloadDbEnv() -{ - // Make sure that no Db's are in use - AssertLockNotHeld(cs_db); - std::unique_lock<RecursiveMutex> lock(cs_db); - m_db_in_use.wait(lock, [this](){ - for (auto& count : mapFileUseCount) { - if (count.second > 0) return false; - } - return true; - }); - - std::vector<std::string> filenames; - for (auto it : m_databases) { - filenames.push_back(it.first); - } - // Close the individual Db's - for (const std::string& filename : filenames) { - CloseDb(filename); - } - // Reset the environment - Flush(true); // This will flush and close the environment - Reset(); - Open(true); -} - -bool BerkeleyBatch::Rewrite(BerkeleyDatabase& database, const char* pszSkip) -{ - if (database.IsDummy()) { - return true; - } - BerkeleyEnvironment *env = database.env.get(); - const std::string& strFile = database.strFile; - while (true) { - { - LOCK(cs_db); - if (!env->mapFileUseCount.count(strFile) || env->mapFileUseCount[strFile] == 0) { - // Flush log data to the dat file - env->CloseDb(strFile); - env->CheckpointLSN(strFile); - env->mapFileUseCount.erase(strFile); - - bool fSuccess = true; - LogPrintf("BerkeleyBatch::Rewrite: Rewriting %s...\n", strFile); - std::string strFileRes = strFile + ".rewrite"; - { // surround usage of db with extra {} - BerkeleyBatch db(database, "r"); - std::unique_ptr<Db> pdbCopy = MakeUnique<Db>(env->dbenv.get(), 0); - - int ret = pdbCopy->open(nullptr, // Txn pointer - strFileRes.c_str(), // Filename - "main", // Logical db name - DB_BTREE, // Database type - DB_CREATE, // Flags - 0); - if (ret > 0) { - LogPrintf("BerkeleyBatch::Rewrite: Can't create database file %s\n", strFileRes); - fSuccess = false; - } - - Dbc* pcursor = db.GetCursor(); - if (pcursor) - while (fSuccess) { - CDataStream ssKey(SER_DISK, CLIENT_VERSION); - CDataStream ssValue(SER_DISK, CLIENT_VERSION); - int ret1 = db.ReadAtCursor(pcursor, ssKey, ssValue); - if (ret1 == DB_NOTFOUND) { - pcursor->close(); - break; - } else if (ret1 != 0) { - pcursor->close(); - fSuccess = false; - break; - } - if (pszSkip && - strncmp(ssKey.data(), pszSkip, std::min(ssKey.size(), strlen(pszSkip))) == 0) - continue; - if (strncmp(ssKey.data(), "\x07version", 8) == 0) { - // Update version: - ssValue.clear(); - ssValue << CLIENT_VERSION; - } - Dbt datKey(ssKey.data(), ssKey.size()); - Dbt datValue(ssValue.data(), ssValue.size()); - int ret2 = pdbCopy->put(nullptr, &datKey, &datValue, DB_NOOVERWRITE); - if (ret2 > 0) - fSuccess = false; - } - if (fSuccess) { - db.Close(); - env->CloseDb(strFile); - if (pdbCopy->close(0)) - fSuccess = false; - } else { - pdbCopy->close(0); - } - } - if (fSuccess) { - Db dbA(env->dbenv.get(), 0); - if (dbA.remove(strFile.c_str(), nullptr, 0)) - fSuccess = false; - Db dbB(env->dbenv.get(), 0); - if (dbB.rename(strFileRes.c_str(), nullptr, strFile.c_str(), 0)) - fSuccess = false; - } - if (!fSuccess) - LogPrintf("BerkeleyBatch::Rewrite: Failed to rewrite database file %s\n", strFileRes); - return fSuccess; - } - } - UninterruptibleSleep(std::chrono::milliseconds{100}); - } -} - - -void BerkeleyEnvironment::Flush(bool fShutdown) -{ - int64_t nStart = GetTimeMillis(); - // Flush log data to the actual data file on all files that are not in use - LogPrint(BCLog::WALLETDB, "BerkeleyEnvironment::Flush: [%s] Flush(%s)%s\n", strPath, fShutdown ? "true" : "false", fDbEnvInit ? "" : " database not started"); - if (!fDbEnvInit) - return; - { - LOCK(cs_db); - std::map<std::string, int>::iterator mi = mapFileUseCount.begin(); - while (mi != mapFileUseCount.end()) { - std::string strFile = (*mi).first; - int nRefCount = (*mi).second; - LogPrint(BCLog::WALLETDB, "BerkeleyEnvironment::Flush: Flushing %s (refcount = %d)...\n", strFile, nRefCount); - if (nRefCount == 0) { - // Move log data to the dat file - CloseDb(strFile); - LogPrint(BCLog::WALLETDB, "BerkeleyEnvironment::Flush: %s checkpoint\n", strFile); - dbenv->txn_checkpoint(0, 0, 0); - LogPrint(BCLog::WALLETDB, "BerkeleyEnvironment::Flush: %s detach\n", strFile); - if (!fMockDb) - dbenv->lsn_reset(strFile.c_str(), 0); - LogPrint(BCLog::WALLETDB, "BerkeleyEnvironment::Flush: %s closed\n", strFile); - mapFileUseCount.erase(mi++); - } else - mi++; - } - LogPrint(BCLog::WALLETDB, "BerkeleyEnvironment::Flush: Flush(%s)%s took %15dms\n", fShutdown ? "true" : "false", fDbEnvInit ? "" : " database not started", GetTimeMillis() - nStart); - if (fShutdown) { - char** listp; - if (mapFileUseCount.empty()) { - dbenv->log_archive(&listp, DB_ARCH_REMOVE); - Close(); - if (!fMockDb) { - fs::remove_all(fs::path(strPath) / "database"); - } - } - } - } -} - -bool BerkeleyBatch::PeriodicFlush(BerkeleyDatabase& database) -{ - if (database.IsDummy()) { - return true; - } - bool ret = false; - BerkeleyEnvironment *env = database.env.get(); - const std::string& strFile = database.strFile; - TRY_LOCK(cs_db, lockDb); - if (lockDb) - { - // Don't do this if any databases are in use - int nRefCount = 0; - std::map<std::string, int>::iterator mit = env->mapFileUseCount.begin(); - while (mit != env->mapFileUseCount.end()) - { - nRefCount += (*mit).second; - mit++; - } - - if (nRefCount == 0) - { - boost::this_thread::interruption_point(); - std::map<std::string, int>::iterator mi = env->mapFileUseCount.find(strFile); - if (mi != env->mapFileUseCount.end()) - { - LogPrint(BCLog::WALLETDB, "Flushing %s\n", strFile); - int64_t nStart = GetTimeMillis(); - - // Flush wallet file so it's self contained - env->CloseDb(strFile); - env->CheckpointLSN(strFile); - - env->mapFileUseCount.erase(mi++); - LogPrint(BCLog::WALLETDB, "Flushed %s %dms\n", strFile, GetTimeMillis() - nStart); - ret = true; - } - } - } - - return ret; -} - -bool BerkeleyDatabase::Rewrite(const char* pszSkip) -{ - return BerkeleyBatch::Rewrite(*this, pszSkip); -} - -bool BerkeleyDatabase::Backup(const std::string& strDest) const -{ - if (IsDummy()) { - return false; - } - while (true) - { - { - LOCK(cs_db); - if (!env->mapFileUseCount.count(strFile) || env->mapFileUseCount[strFile] == 0) - { - // Flush log data to the dat file - env->CloseDb(strFile); - env->CheckpointLSN(strFile); - env->mapFileUseCount.erase(strFile); - - // Copy wallet file - fs::path pathSrc = env->Directory() / strFile; - fs::path pathDest(strDest); - if (fs::is_directory(pathDest)) - pathDest /= strFile; - - try { - if (fs::equivalent(pathSrc, pathDest)) { - LogPrintf("cannot backup to wallet source file %s\n", pathDest.string()); - return false; - } - - fs::copy_file(pathSrc, pathDest, fs::copy_option::overwrite_if_exists); - LogPrintf("copied %s to %s\n", strFile, pathDest.string()); - return true; - } catch (const fs::filesystem_error& e) { - LogPrintf("error copying %s to %s - %s\n", strFile, pathDest.string(), fsbridge::get_filesystem_error_message(e)); - return false; - } - } - } - UninterruptibleSleep(std::chrono::milliseconds{100}); - } -} - -void BerkeleyDatabase::Flush(bool shutdown) -{ - if (!IsDummy()) { - env->Flush(shutdown); - if (shutdown) { - LOCK(cs_db); - g_dbenvs.erase(env->Directory().string()); - env = nullptr; - } else { - // TODO: To avoid g_dbenvs.erase erasing the environment prematurely after the - // first database shutdown when multiple databases are open in the same - // environment, should replace raw database `env` pointers with shared or weak - // pointers, or else separate the database and environment shutdowns so - // environments can be shut down after databases. - env->m_fileids.erase(strFile); - } - } -} - -void BerkeleyDatabase::ReloadDbEnv() -{ - if (!IsDummy()) { - env->ReloadDbEnv(); - } -} - -std::string BerkeleyDatabaseVersion() -{ - return DbEnv::version(nullptr, nullptr, nullptr); -} diff --git a/src/wallet/db.h b/src/wallet/db.h index 54ce144ffc..0afaba5fd1 100644 --- a/src/wallet/db.h +++ b/src/wallet/db.h @@ -8,387 +8,191 @@ #include <clientversion.h> #include <fs.h> -#include <serialize.h> #include <streams.h> -#include <util/system.h> +#include <util/memory.h> #include <atomic> -#include <map> #include <memory> #include <string> -#include <unordered_map> -#include <vector> - -#if defined(__GNUC__) && !defined(__clang__) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wsuggest-override" -#endif -#include <db_cxx.h> -#if defined(__GNUC__) && !defined(__clang__) -#pragma GCC diagnostic pop -#endif struct bilingual_str; -static const unsigned int DEFAULT_WALLET_DBLOGSIZE = 100; -static const bool DEFAULT_WALLET_PRIVDB = true; - -struct WalletDatabaseFileId { - u_int8_t value[DB_FILE_ID_LEN]; - bool operator==(const WalletDatabaseFileId& rhs) const; -}; - -class BerkeleyDatabase; - -class BerkeleyEnvironment -{ -private: - bool fDbEnvInit; - bool fMockDb; - // Don't change into fs::path, as that can result in - // shutdown problems/crashes caused by a static initialized internal pointer. - std::string strPath; - -public: - std::unique_ptr<DbEnv> dbenv; - std::map<std::string, int> mapFileUseCount; - std::map<std::string, std::reference_wrapper<BerkeleyDatabase>> m_databases; - std::unordered_map<std::string, WalletDatabaseFileId> m_fileids; - std::condition_variable_any m_db_in_use; - - BerkeleyEnvironment(const fs::path& env_directory); - BerkeleyEnvironment(); - ~BerkeleyEnvironment(); - void Reset(); - - bool IsMock() const { return fMockDb; } - bool IsInitialized() const { return fDbEnvInit; } - bool IsDatabaseLoaded(const std::string& db_filename) const { return m_databases.find(db_filename) != m_databases.end(); } - fs::path Directory() const { return strPath; } - - bool Verify(const std::string& strFile); - - bool Open(bool retry); - void Close(); - void Flush(bool fShutdown); - void CheckpointLSN(const std::string& strFile); - - void CloseDb(const std::string& strFile); - void ReloadDbEnv(); - - DbTxn* TxnBegin(int flags = DB_TXN_WRITE_NOSYNC) - { - DbTxn* ptxn = nullptr; - int ret = dbenv->txn_begin(nullptr, &ptxn, flags); - if (!ptxn || ret != 0) - return nullptr; - return ptxn; - } -}; - -/** Return whether a wallet database is currently loaded. */ -bool IsWalletLoaded(const fs::path& wallet_path); - /** Given a wallet directory path or legacy file path, return path to main data file in the wallet database. */ fs::path WalletDataFilePath(const fs::path& wallet_path); +void SplitWalletPath(const fs::path& wallet_path, fs::path& env_directory, std::string& database_filename); -/** Get BerkeleyEnvironment and database filename given a wallet path. */ -std::shared_ptr<BerkeleyEnvironment> GetWalletEnv(const fs::path& wallet_path, std::string& database_filename); - -/** An instance of this class represents one database. - * For BerkeleyDB this is just a (env, strFile) tuple. - **/ -class BerkeleyDatabase +/** RAII class that provides access to a WalletDatabase */ +class DatabaseBatch { - friend class BerkeleyBatch; -public: - /** Create dummy DB handle */ - BerkeleyDatabase() : nUpdateCounter(0), nLastSeen(0), nLastFlushed(0), nLastWalletUpdate(0), env(nullptr) - { - } - - /** Create DB handle to real database */ - BerkeleyDatabase(std::shared_ptr<BerkeleyEnvironment> env, std::string filename) : - nUpdateCounter(0), nLastSeen(0), nLastFlushed(0), nLastWalletUpdate(0), env(std::move(env)), strFile(std::move(filename)) - { - auto inserted = this->env->m_databases.emplace(strFile, std::ref(*this)); - assert(inserted.second); - } - - ~BerkeleyDatabase() { - if (env) { - size_t erased = env->m_databases.erase(strFile); - assert(erased == 1); - } - } - - /** Return object for accessing database at specified path. */ - static std::unique_ptr<BerkeleyDatabase> Create(const fs::path& path) - { - std::string filename; - return MakeUnique<BerkeleyDatabase>(GetWalletEnv(path, filename), std::move(filename)); - } - - /** Return object for accessing dummy database with no read/write capabilities. */ - static std::unique_ptr<BerkeleyDatabase> CreateDummy() - { - return MakeUnique<BerkeleyDatabase>(); - } - - /** Return object for accessing temporary in-memory database. */ - static std::unique_ptr<BerkeleyDatabase> CreateMock() - { - return MakeUnique<BerkeleyDatabase>(std::make_shared<BerkeleyEnvironment>(), ""); - } - - /** Rewrite the entire database on disk, with the exception of key pszSkip if non-zero - */ - bool Rewrite(const char* pszSkip=nullptr); - - /** Back up the entire database to a file. - */ - bool Backup(const std::string& strDest) const; - - /** Make sure all changes are flushed to disk. - */ - void Flush(bool shutdown); - - void IncrementUpdateCounter(); - - void ReloadDbEnv(); - - std::atomic<unsigned int> nUpdateCounter; - unsigned int nLastSeen; - unsigned int nLastFlushed; - int64_t nLastWalletUpdate; - - /** - * Pointer to shared database environment. - * - * Normally there is only one BerkeleyDatabase object per - * BerkeleyEnvivonment, but in the special, backwards compatible case where - * multiple wallet BDB data files are loaded from the same directory, this - * will point to a shared instance that gets freed when the last data file - * is closed. - */ - std::shared_ptr<BerkeleyEnvironment> env; - - /** Database pointer. This is initialized lazily and reset during flushes, so it can be null. */ - std::unique_ptr<Db> m_db; - private: - std::string strFile; - - /** Return whether this database handle is a dummy for testing. - * Only to be used at a low level, application should ideally not care - * about this. - */ - bool IsDummy() const { return env == nullptr; } -}; - -/** RAII class that provides access to a Berkeley database */ -class BerkeleyBatch -{ - /** RAII class that automatically cleanses its data on destruction */ - class SafeDbt final - { - Dbt m_dbt; - - public: - // construct Dbt with internally-managed data - SafeDbt(); - // construct Dbt with provided data - SafeDbt(void* data, size_t size); - ~SafeDbt(); - - // delegate to Dbt - const void* get_data() const; - u_int32_t get_size() const; - - // conversion operator to access the underlying Dbt - operator Dbt*(); - }; - -protected: - Db* pdb; - std::string strFile; - DbTxn* activeTxn; - bool fReadOnly; - bool fFlushOnClose; - BerkeleyEnvironment *env; + virtual bool ReadKey(CDataStream&& key, CDataStream& value) = 0; + virtual bool WriteKey(CDataStream&& key, CDataStream&& value, bool overwrite=true) = 0; + virtual bool EraseKey(CDataStream&& key) = 0; + virtual bool HasKey(CDataStream&& key) = 0; public: - explicit BerkeleyBatch(BerkeleyDatabase& database, const char* pszMode = "r+", bool fFlushOnCloseIn=true); - ~BerkeleyBatch() { Close(); } + explicit DatabaseBatch() {} + virtual ~DatabaseBatch() {} - BerkeleyBatch(const BerkeleyBatch&) = delete; - BerkeleyBatch& operator=(const BerkeleyBatch&) = delete; + DatabaseBatch(const DatabaseBatch&) = delete; + DatabaseBatch& operator=(const DatabaseBatch&) = delete; - void Flush(); - void Close(); - - /* flush the wallet passively (TRY_LOCK) - ideal to be called periodically */ - static bool PeriodicFlush(BerkeleyDatabase& database); - /* verifies the database environment */ - static bool VerifyEnvironment(const fs::path& file_path, bilingual_str& errorStr); - /* verifies the database file */ - static bool VerifyDatabaseFile(const fs::path& file_path, bilingual_str& errorStr); + virtual void Flush() = 0; + virtual void Close() = 0; template <typename K, typename T> bool Read(const K& key, T& value) { - if (!pdb) - return false; - - // Key CDataStream ssKey(SER_DISK, CLIENT_VERSION); ssKey.reserve(1000); ssKey << key; - SafeDbt datKey(ssKey.data(), ssKey.size()); - - // Read - SafeDbt datValue; - int ret = pdb->get(activeTxn, datKey, datValue, 0); - bool success = false; - if (datValue.get_data() != nullptr) { - // Unserialize value - try { - CDataStream ssValue((char*)datValue.get_data(), (char*)datValue.get_data() + datValue.get_size(), SER_DISK, CLIENT_VERSION); - ssValue >> value; - success = true; - } catch (const std::exception&) { - // In this case success remains 'false' - } + + CDataStream ssValue(SER_DISK, CLIENT_VERSION); + if (!ReadKey(std::move(ssKey), ssValue)) return false; + try { + ssValue >> value; + return true; + } catch (const std::exception&) { + return false; } - return ret == 0 && success; } template <typename K, typename T> bool Write(const K& key, const T& value, bool fOverwrite = true) { - if (!pdb) - return true; - if (fReadOnly) - assert(!"Write called on database in read-only mode"); - - // Key CDataStream ssKey(SER_DISK, CLIENT_VERSION); ssKey.reserve(1000); ssKey << key; - SafeDbt datKey(ssKey.data(), ssKey.size()); - // Value CDataStream ssValue(SER_DISK, CLIENT_VERSION); ssValue.reserve(10000); ssValue << value; - SafeDbt datValue(ssValue.data(), ssValue.size()); - // Write - int ret = pdb->put(activeTxn, datKey, datValue, (fOverwrite ? 0 : DB_NOOVERWRITE)); - return (ret == 0); + return WriteKey(std::move(ssKey), std::move(ssValue), fOverwrite); } template <typename K> bool Erase(const K& key) { - if (!pdb) - return false; - if (fReadOnly) - assert(!"Erase called on database in read-only mode"); - - // Key CDataStream ssKey(SER_DISK, CLIENT_VERSION); ssKey.reserve(1000); ssKey << key; - SafeDbt datKey(ssKey.data(), ssKey.size()); - // Erase - int ret = pdb->del(activeTxn, datKey, 0); - return (ret == 0 || ret == DB_NOTFOUND); + return EraseKey(std::move(ssKey)); } template <typename K> bool Exists(const K& key) { - if (!pdb) - return false; - - // Key CDataStream ssKey(SER_DISK, CLIENT_VERSION); ssKey.reserve(1000); ssKey << key; - SafeDbt datKey(ssKey.data(), ssKey.size()); - // Exists - int ret = pdb->exists(activeTxn, datKey, 0); - return (ret == 0); + return HasKey(std::move(ssKey)); } - Dbc* GetCursor() - { - if (!pdb) - return nullptr; - Dbc* pcursor = nullptr; - int ret = pdb->cursor(nullptr, &pcursor, 0); - if (ret != 0) - return nullptr; - return pcursor; - } + virtual bool StartCursor() = 0; + virtual bool ReadAtCursor(CDataStream& ssKey, CDataStream& ssValue, bool& complete) = 0; + virtual void CloseCursor() = 0; + virtual bool TxnBegin() = 0; + virtual bool TxnCommit() = 0; + virtual bool TxnAbort() = 0; +}; - int ReadAtCursor(Dbc* pcursor, CDataStream& ssKey, CDataStream& ssValue) - { - // Read at cursor - SafeDbt datKey; - SafeDbt datValue; - int ret = pcursor->get(datKey, datValue, DB_NEXT); - if (ret != 0) - return ret; - else if (datKey.get_data() == nullptr || datValue.get_data() == nullptr) - return 99999; - - // Convert to streams - ssKey.SetType(SER_DISK); - ssKey.clear(); - ssKey.write((char*)datKey.get_data(), datKey.get_size()); - ssValue.SetType(SER_DISK); - ssValue.clear(); - ssValue.write((char*)datValue.get_data(), datValue.get_size()); - return 0; - } +/** An instance of this class represents one database. + **/ +class WalletDatabase +{ +public: + /** Create dummy DB handle */ + WalletDatabase() : nUpdateCounter(0), nLastSeen(0), nLastFlushed(0), nLastWalletUpdate(0) {} + virtual ~WalletDatabase() {}; - bool TxnBegin() - { - if (!pdb || activeTxn) - return false; - DbTxn* ptxn = env->TxnBegin(); - if (!ptxn) - return false; - activeTxn = ptxn; - return true; - } + /** Open the database if it is not already opened. */ + virtual void Open(const char* mode) = 0; - bool TxnCommit() - { - if (!pdb || !activeTxn) - return false; - int ret = activeTxn->commit(0); - activeTxn = nullptr; - return (ret == 0); - } + //! Counts the number of active database users to be sure that the database is not closed while someone is using it + std::atomic<int> m_refcount{0}; + /** Indicate the a new database user has began using the database. Increments m_refcount */ + virtual void AddRef() = 0; + /** Indicate that database user has stopped using the database and that it could be flushed or closed. Decrement m_refcount */ + virtual void RemoveRef() = 0; - bool TxnAbort() - { - if (!pdb || !activeTxn) - return false; - int ret = activeTxn->abort(); - activeTxn = nullptr; - return (ret == 0); - } + /** Rewrite the entire database on disk, with the exception of key pszSkip if non-zero + */ + virtual bool Rewrite(const char* pszSkip=nullptr) = 0; + + /** Back up the entire database to a file. + */ + virtual bool Backup(const std::string& strDest) const = 0; + + /** Make sure all changes are flushed to database file. + */ + virtual void Flush() = 0; + /** Flush to the database file and close the database. + * Also close the environment if no other databases are open in it. + */ + virtual void Close() = 0; + /* flush the wallet passively (TRY_LOCK) + ideal to be called periodically */ + virtual bool PeriodicFlush() = 0; + + virtual void IncrementUpdateCounter() = 0; + + virtual void ReloadDbEnv() = 0; + + std::atomic<unsigned int> nUpdateCounter; + unsigned int nLastSeen; + unsigned int nLastFlushed; + int64_t nLastWalletUpdate; + + /** Verifies the environment and database file */ + virtual bool Verify(bilingual_str& error) = 0; + + std::string m_file_path; + + /** Make a DatabaseBatch connected to this database */ + virtual std::unique_ptr<DatabaseBatch> MakeBatch(const char* mode = "r+", bool flush_on_close = true) = 0; +}; - bool static Rewrite(BerkeleyDatabase& database, const char* pszSkip = nullptr); +/** RAII class that provides access to a DummyDatabase. Never fails. */ +class DummyBatch : public DatabaseBatch +{ +private: + bool ReadKey(CDataStream&& key, CDataStream& value) override { return true; } + bool WriteKey(CDataStream&& key, CDataStream&& value, bool overwrite=true) override { return true; } + bool EraseKey(CDataStream&& key) override { return true; } + bool HasKey(CDataStream&& key) override { return true; } + +public: + void Flush() override {} + void Close() override {} + + bool StartCursor() override { return true; } + bool ReadAtCursor(CDataStream& ssKey, CDataStream& ssValue, bool& complete) override { return true; } + void CloseCursor() override {} + bool TxnBegin() override { return true; } + bool TxnCommit() override { return true; } + bool TxnAbort() override { return true; } }; -std::string BerkeleyDatabaseVersion(); +/** A dummy WalletDatabase that does nothing and never fails. Only used by unit tests. + **/ +class DummyDatabase : public WalletDatabase +{ +public: + void Open(const char* mode) override {}; + void AddRef() override {} + void RemoveRef() override {} + bool Rewrite(const char* pszSkip=nullptr) override { return true; } + bool Backup(const std::string& strDest) const override { return true; } + void Close() override {} + void Flush() override {} + bool PeriodicFlush() override { return true; } + void IncrementUpdateCounter() override { ++nUpdateCounter; } + void ReloadDbEnv() override {} + bool Verify(bilingual_str& errorStr) override { return true; } + std::unique_ptr<DatabaseBatch> MakeBatch(const char* mode = "r+", bool flush_on_close = true) override { return MakeUnique<DummyBatch>(); } +}; #endif // BITCOIN_WALLET_DB_H diff --git a/src/wallet/init.cpp b/src/wallet/init.cpp index 3885eb6185..52162ab521 100644 --- a/src/wallet/init.cpp +++ b/src/wallet/init.cpp @@ -7,8 +7,9 @@ #include <interfaces/chain.h> #include <net.h> #include <node/context.h> +#include <node/ui_interface.h> #include <outputtype.h> -#include <ui_interface.h> +#include <util/check.h> #include <util/moneystr.h> #include <util/system.h> #include <util/translation.h> @@ -16,14 +17,14 @@ #include <wallet/wallet.h> #include <walletinitinterface.h> -class WalletInit : public WalletInitInterface { +class WalletInit : public WalletInitInterface +{ public: - //! Was the wallet component compiled in. bool HasWalletSupport() const override {return true;} //! Return the wallets help message. - void AddWalletOptions() const override; + void AddWalletOptions(ArgsManager& argsman) const override; //! Wallets parameter interaction bool ParameterInteraction() const override; @@ -34,42 +35,42 @@ public: const WalletInitInterface& g_wallet_init_interface = WalletInit(); -void WalletInit::AddWalletOptions() const +void WalletInit::AddWalletOptions(ArgsManager& argsman) const { - gArgs.AddArg("-addresstype", strprintf("What type of addresses to use (\"legacy\", \"p2sh-segwit\", or \"bech32\", default: \"%s\")", FormatOutputType(DEFAULT_ADDRESS_TYPE)), ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); - gArgs.AddArg("-avoidpartialspends", strprintf("Group outputs by address, selecting all or none, instead of selecting on a per-output basis. Privacy is improved as an address is only used once (unless someone sends to it after spending from it), but may result in slightly higher fees as suboptimal coin selection may result due to the added limitation (default: %u (always enabled for wallets with \"avoid_reuse\" enabled))", DEFAULT_AVOIDPARTIALSPENDS), ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); - gArgs.AddArg("-changetype", "What type of change to use (\"legacy\", \"p2sh-segwit\", or \"bech32\"). Default is same as -addresstype, except when -addresstype=p2sh-segwit a native segwit output is used when sending to a native segwit address)", ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); - gArgs.AddArg("-disablewallet", "Do not load the wallet and disable wallet RPC calls", ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); - gArgs.AddArg("-discardfee=<amt>", strprintf("The fee rate (in %s/kB) that indicates your tolerance for discarding change by adding it to the fee (default: %s). " + argsman.AddArg("-addresstype", strprintf("What type of addresses to use (\"legacy\", \"p2sh-segwit\", or \"bech32\", default: \"%s\")", FormatOutputType(DEFAULT_ADDRESS_TYPE)), ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); + argsman.AddArg("-avoidpartialspends", strprintf("Group outputs by address, selecting all or none, instead of selecting on a per-output basis. Privacy is improved as an address is only used once (unless someone sends to it after spending from it), but may result in slightly higher fees as suboptimal coin selection may result due to the added limitation (default: %u (always enabled for wallets with \"avoid_reuse\" enabled))", DEFAULT_AVOIDPARTIALSPENDS), ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); + argsman.AddArg("-changetype", "What type of change to use (\"legacy\", \"p2sh-segwit\", or \"bech32\"). Default is same as -addresstype, except when -addresstype=p2sh-segwit a native segwit output is used when sending to a native segwit address)", ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); + argsman.AddArg("-disablewallet", "Do not load the wallet and disable wallet RPC calls", ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); + argsman.AddArg("-discardfee=<amt>", strprintf("The fee rate (in %s/kB) that indicates your tolerance for discarding change by adding it to the fee (default: %s). " "Note: An output is discarded if it is dust at this rate, but we will always discard up to the dust relay fee and a discard fee above that is limited by the fee estimate for the longest target", CURRENCY_UNIT, FormatMoney(DEFAULT_DISCARD_FEE)), ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); - gArgs.AddArg("-fallbackfee=<amt>", strprintf("A fee rate (in %s/kB) that will be used when fee estimation has insufficient data. 0 to entirely disable the fallbackfee feature. (default: %s)", + argsman.AddArg("-fallbackfee=<amt>", strprintf("A fee rate (in %s/kB) that will be used when fee estimation has insufficient data. 0 to entirely disable the fallbackfee feature. (default: %s)", CURRENCY_UNIT, FormatMoney(DEFAULT_FALLBACK_FEE)), ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); - gArgs.AddArg("-keypool=<n>", strprintf("Set key pool size to <n> (default: %u). Warning: Smaller sizes may increase the risk of losing funds when restoring from an old backup, if none of the addresses in the original keypool have been used.", DEFAULT_KEYPOOL_SIZE), ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); - gArgs.AddArg("-maxtxfee=<amt>", strprintf("Maximum total fees (in %s) to use in a single wallet transaction; setting this too low may abort large transactions (default: %s)", + argsman.AddArg("-keypool=<n>", strprintf("Set key pool size to <n> (default: %u). Warning: Smaller sizes may increase the risk of losing funds when restoring from an old backup, if none of the addresses in the original keypool have been used.", DEFAULT_KEYPOOL_SIZE), ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); + argsman.AddArg("-maxtxfee=<amt>", strprintf("Maximum total fees (in %s) to use in a single wallet transaction; setting this too low may abort large transactions (default: %s)", CURRENCY_UNIT, FormatMoney(DEFAULT_TRANSACTION_MAXFEE)), ArgsManager::ALLOW_ANY, OptionsCategory::DEBUG_TEST); - gArgs.AddArg("-mintxfee=<amt>", strprintf("Fees (in %s/kB) smaller than this are considered zero fee for transaction creation (default: %s)", + argsman.AddArg("-mintxfee=<amt>", strprintf("Fees (in %s/kB) smaller than this are considered zero fee for transaction creation (default: %s)", CURRENCY_UNIT, FormatMoney(DEFAULT_TRANSACTION_MINFEE)), ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); - gArgs.AddArg("-paytxfee=<amt>", strprintf("Fee (in %s/kB) to add to transactions you send (default: %s)", + argsman.AddArg("-paytxfee=<amt>", strprintf("Fee (in %s/kB) to add to transactions you send (default: %s)", CURRENCY_UNIT, FormatMoney(CFeeRate{DEFAULT_PAY_TX_FEE}.GetFeePerK())), ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); - gArgs.AddArg("-rescan", "Rescan the block chain for missing wallet transactions on startup", ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); - gArgs.AddArg("-spendzeroconfchange", strprintf("Spend unconfirmed change when sending transactions (default: %u)", DEFAULT_SPEND_ZEROCONF_CHANGE), ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); - gArgs.AddArg("-txconfirmtarget=<n>", strprintf("If paytxfee is not set, include enough fee so transactions begin confirmation on average within n blocks (default: %u)", DEFAULT_TX_CONFIRM_TARGET), ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); - gArgs.AddArg("-wallet=<path>", "Specify wallet database path. Can be specified multiple times to load multiple wallets. Path is interpreted relative to <walletdir> if it is not absolute, and will be created if it does not exist (as a directory containing a wallet.dat file and log files). For backwards compatibility this will also accept names of existing data files in <walletdir>.)", ArgsManager::ALLOW_ANY | ArgsManager::NETWORK_ONLY, OptionsCategory::WALLET); - gArgs.AddArg("-walletbroadcast", strprintf("Make the wallet broadcast transactions (default: %u)", DEFAULT_WALLETBROADCAST), ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); - gArgs.AddArg("-walletdir=<dir>", "Specify directory to hold wallets (default: <datadir>/wallets if it exists, otherwise <datadir>)", ArgsManager::ALLOW_ANY | ArgsManager::NETWORK_ONLY, OptionsCategory::WALLET); + argsman.AddArg("-rescan", "Rescan the block chain for missing wallet transactions on startup", ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); + argsman.AddArg("-spendzeroconfchange", strprintf("Spend unconfirmed change when sending transactions (default: %u)", DEFAULT_SPEND_ZEROCONF_CHANGE), ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); + argsman.AddArg("-txconfirmtarget=<n>", strprintf("If paytxfee is not set, include enough fee so transactions begin confirmation on average within n blocks (default: %u)", DEFAULT_TX_CONFIRM_TARGET), ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); + argsman.AddArg("-wallet=<path>", "Specify wallet database path. Can be specified multiple times to load multiple wallets. Path is interpreted relative to <walletdir> if it is not absolute, and will be created if it does not exist (as a directory containing a wallet.dat file and log files). For backwards compatibility this will also accept names of existing data files in <walletdir>.)", ArgsManager::ALLOW_ANY | ArgsManager::NETWORK_ONLY, OptionsCategory::WALLET); + argsman.AddArg("-walletbroadcast", strprintf("Make the wallet broadcast transactions (default: %u)", DEFAULT_WALLETBROADCAST), ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); + argsman.AddArg("-walletdir=<dir>", "Specify directory to hold wallets (default: <datadir>/wallets if it exists, otherwise <datadir>)", ArgsManager::ALLOW_ANY | ArgsManager::NETWORK_ONLY, OptionsCategory::WALLET); #if HAVE_SYSTEM - gArgs.AddArg("-walletnotify=<cmd>", "Execute command when a wallet transaction changes. %s in cmd is replaced by TxID and %w is replaced by wallet name. %w is not currently implemented on windows. On systems where %w is supported, it should NOT be quoted because this would break shell escaping used to invoke the command.", ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); + argsman.AddArg("-walletnotify=<cmd>", "Execute command when a wallet transaction changes. %s in cmd is replaced by TxID and %w is replaced by wallet name. %w is not currently implemented on windows. On systems where %w is supported, it should NOT be quoted because this would break shell escaping used to invoke the command.", ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); #endif - gArgs.AddArg("-walletrbf", strprintf("Send transactions with full-RBF opt-in enabled (RPC only, default: %u)", DEFAULT_WALLET_RBF), ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); - gArgs.AddArg("-zapwallettxes=<mode>", "Delete all wallet transactions and only recover those parts of the blockchain through -rescan on startup" + argsman.AddArg("-walletrbf", strprintf("Send transactions with full-RBF opt-in enabled (RPC only, default: %u)", DEFAULT_WALLET_RBF), ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); + argsman.AddArg("-zapwallettxes=<mode>", "Delete all wallet transactions and only recover those parts of the blockchain through -rescan on startup" " (1 = keep tx meta data e.g. payment request information, 2 = drop tx meta data)", ArgsManager::ALLOW_ANY, OptionsCategory::WALLET); - gArgs.AddArg("-dblogsize=<n>", strprintf("Flush wallet database activity from memory to disk log every <n> megabytes (default: %u)", DEFAULT_WALLET_DBLOGSIZE), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::WALLET_DEBUG_TEST); - gArgs.AddArg("-flushwallet", strprintf("Run a thread to flush wallet periodically (default: %u)", DEFAULT_FLUSHWALLET), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::WALLET_DEBUG_TEST); - gArgs.AddArg("-privdb", strprintf("Sets the DB_PRIVATE flag in the wallet db environment (default: %u)", DEFAULT_WALLET_PRIVDB), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::WALLET_DEBUG_TEST); - gArgs.AddArg("-walletrejectlongchains", strprintf("Wallet will not create transactions that violate mempool chain limits (default: %u)", DEFAULT_WALLET_REJECT_LONG_CHAINS), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::WALLET_DEBUG_TEST); + argsman.AddArg("-dblogsize=<n>", strprintf("Flush wallet database activity from memory to disk log every <n> megabytes (default: %u)", DEFAULT_WALLET_DBLOGSIZE), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::WALLET_DEBUG_TEST); + argsman.AddArg("-flushwallet", strprintf("Run a thread to flush wallet periodically (default: %u)", DEFAULT_FLUSHWALLET), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::WALLET_DEBUG_TEST); + argsman.AddArg("-privdb", strprintf("Sets the DB_PRIVATE flag in the wallet db environment (default: %u)", DEFAULT_WALLET_PRIVDB), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::WALLET_DEBUG_TEST); + argsman.AddArg("-walletrejectlongchains", strprintf("Wallet will not create transactions that violate mempool chain limits (default: %u)", DEFAULT_WALLET_REJECT_LONG_CHAINS), ArgsManager::ALLOW_ANY | ArgsManager::DEBUG_ONLY, OptionsCategory::WALLET_DEBUG_TEST); } bool WalletInit::ParameterInteraction() const @@ -112,10 +113,11 @@ bool WalletInit::ParameterInteraction() const void WalletInit::Construct(NodeContext& node) const { - if (gArgs.GetBoolArg("-disablewallet", DEFAULT_DISABLE_WALLET)) { + ArgsManager& args = *Assert(node.args); + if (args.GetBoolArg("-disablewallet", DEFAULT_DISABLE_WALLET)) { LogPrintf("Wallet disabled!\n"); return; } - gArgs.SoftSetArg("-wallet", ""); - node.chain_clients.emplace_back(interfaces::MakeWalletClient(*node.chain, gArgs.GetArgs("-wallet"))); + args.SoftSetArg("-wallet", ""); + node.chain_clients.emplace_back(interfaces::MakeWalletClient(*node.chain, args, args.GetArgs("-wallet"))); } diff --git a/src/wallet/load.cpp b/src/wallet/load.cpp index 8df3e78215..2a81d30133 100644 --- a/src/wallet/load.cpp +++ b/src/wallet/load.cpp @@ -11,6 +11,7 @@ #include <util/system.h> #include <util/translation.h> #include <wallet/wallet.h> +#include <wallet/walletdb.h> bool VerifyWallets(interfaces::Chain& chain, const std::vector<std::string>& wallet_files) { @@ -82,28 +83,30 @@ bool LoadWallets(interfaces::Chain& chain, const std::vector<std::string>& walle } } -void StartWallets(CScheduler& scheduler) +void StartWallets(CScheduler& scheduler, const ArgsManager& args) { for (const std::shared_ptr<CWallet>& pwallet : GetWallets()) { pwallet->postInitProcess(); } // Schedule periodic wallet flushes and tx rebroadcasts - scheduler.scheduleEvery(MaybeCompactWalletDB, std::chrono::milliseconds{500}); + if (args.GetBoolArg("-flushwallet", DEFAULT_FLUSHWALLET)) { + scheduler.scheduleEvery(MaybeCompactWalletDB, std::chrono::milliseconds{500}); + } scheduler.scheduleEvery(MaybeResendWalletTxs, std::chrono::milliseconds{1000}); } void FlushWallets() { for (const std::shared_ptr<CWallet>& pwallet : GetWallets()) { - pwallet->Flush(false); + pwallet->Flush(); } } void StopWallets() { for (const std::shared_ptr<CWallet>& pwallet : GetWallets()) { - pwallet->Flush(true); + pwallet->Close(); } } diff --git a/src/wallet/load.h b/src/wallet/load.h index e24b1f2e69..ff4f5b4b23 100644 --- a/src/wallet/load.h +++ b/src/wallet/load.h @@ -9,6 +9,7 @@ #include <string> #include <vector> +class ArgsManager; class CScheduler; namespace interfaces { @@ -22,7 +23,7 @@ bool VerifyWallets(interfaces::Chain& chain, const std::vector<std::string>& wal bool LoadWallets(interfaces::Chain& chain, const std::vector<std::string>& wallet_files); //! Complete startup of wallets. -void StartWallets(CScheduler& scheduler); +void StartWallets(CScheduler& scheduler, const ArgsManager& args); //! Flush all wallets in preparation for shutdown. void FlushWallets(); diff --git a/src/wallet/rpcdump.cpp b/src/wallet/rpcdump.cpp index d5f6d63a46..e0c3a1287a 100644 --- a/src/wallet/rpcdump.cpp +++ b/src/wallet/rpcdump.cpp @@ -34,7 +34,7 @@ std::string static EncodeDumpString(const std::string &str) { std::stringstream ret; for (const unsigned char c : str) { if (c <= 32 || c >= 128 || c == '%') { - ret << '%' << HexStr(&c, &c + 1); + ret << '%' << HexStr(Span<const unsigned char>(&c, 1)); } else { ret << c; } @@ -92,12 +92,6 @@ static void RescanWallet(CWallet& wallet, const WalletRescanReserver& reserver, UniValue importprivkey(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"importprivkey", "\nAdds a private key (as returned by dumpprivkey) to your wallet. Requires a new wallet backup.\n" "Hint: use importmulti to import more than one private key.\n" @@ -124,6 +118,10 @@ UniValue importprivkey(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + if (pwallet->IsWalletFlagSet(WALLET_FLAG_DISABLE_PRIVATE_KEYS)) { throw JSONRPCError(RPC_WALLET_ERROR, "Cannot import private keys to a wallet with private keys disabled"); } @@ -195,12 +193,6 @@ UniValue importprivkey(const JSONRPCRequest& request) UniValue abortrescan(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"abortrescan", "\nStops current wallet rescan triggered by an RPC call, e.g. by an importprivkey call.\n" "Note: Use \"getwalletinfo\" to query the scanning progress.\n", @@ -216,6 +208,10 @@ UniValue abortrescan(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + if (!pwallet->IsScanning() || pwallet->IsAbortingRescan()) return false; pwallet->AbortRescan(); return true; @@ -223,12 +219,6 @@ UniValue abortrescan(const JSONRPCRequest& request) UniValue importaddress(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"importaddress", "\nAdds an address or script (in hex) that can be watched as if it were in your wallet but cannot be used to spend. Requires a new wallet backup.\n" "\nNote: This call can take over an hour to complete if rescan is true, during that time, other rpc calls\n" @@ -255,6 +245,10 @@ UniValue importaddress(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + EnsureLegacyScriptPubKeyMan(*pwallet, true); std::string strLabel; @@ -303,7 +297,7 @@ UniValue importaddress(const JSONRPCRequest& request) pwallet->ImportScripts(scripts, 0 /* timestamp */); if (fP2SH) { - scripts.insert(GetScriptForDestination(ScriptHash(CScriptID(redeem_script)))); + scripts.insert(GetScriptForDestination(ScriptHash(redeem_script))); } pwallet->ImportScriptPubKeys(strLabel, scripts, false /* have_solving_data */, true /* apply_label */, 1 /* timestamp */); @@ -325,12 +319,6 @@ UniValue importaddress(const JSONRPCRequest& request) UniValue importprunedfunds(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"importprunedfunds", "\nImports funds without rescan. Corresponding address or script must previously be included in wallet. Aimed towards pruned wallets. The end-user is responsible to import additional transactions that subsequently spend the imported outputs or rescan after the point in the blockchain the transaction is included.\n", { @@ -341,6 +329,10 @@ UniValue importprunedfunds(const JSONRPCRequest& request) RPCExamples{""}, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + CMutableTransaction tx; if (!DecodeHexTx(tx, request.params[0].get_str())) throw JSONRPCError(RPC_DESERIALIZATION_ERROR, "TX decode failed"); @@ -383,12 +375,6 @@ UniValue importprunedfunds(const JSONRPCRequest& request) UniValue removeprunedfunds(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"removeprunedfunds", "\nDeletes the specified transaction from the wallet. Meant for use with pruned wallets and as a companion to importprunedfunds. This will affect wallet balances.\n", { @@ -402,6 +388,10 @@ UniValue removeprunedfunds(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + LOCK(pwallet->cs_wallet); uint256 hash(ParseHashV(request.params[0], "txid")); @@ -422,12 +412,6 @@ UniValue removeprunedfunds(const JSONRPCRequest& request) UniValue importpubkey(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"importpubkey", "\nAdds a public key (in hex) that can be watched as if it were in your wallet but cannot be used to spend. Requires a new wallet backup.\n" "Hint: use importmulti to import more than one public key.\n" @@ -450,6 +434,10 @@ UniValue importpubkey(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + EnsureLegacyScriptPubKeyMan(*wallet, true); std::string strLabel; @@ -509,12 +497,6 @@ UniValue importpubkey(const JSONRPCRequest& request) UniValue importwallet(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"importwallet", "\nImports keys from a wallet dump file (see dumpwallet). Requires a new wallet backup to include imported keys.\n" "Note: Use \"getwalletinfo\" to query the scanning progress.\n", @@ -532,6 +514,10 @@ UniValue importwallet(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + EnsureLegacyScriptPubKeyMan(*wallet, true); if (pwallet->chain().havePruned()) { @@ -667,12 +653,6 @@ UniValue importwallet(const JSONRPCRequest& request) UniValue dumpprivkey(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - const CWallet* const pwallet = wallet.get(); - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"dumpprivkey", "\nReveals the private key corresponding to 'address'.\n" "Then the importprivkey can be used with this output\n", @@ -689,6 +669,10 @@ UniValue dumpprivkey(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + const CWallet* const pwallet = wallet.get(); + LegacyScriptPubKeyMan& spk_man = EnsureLegacyScriptPubKeyMan(*wallet); LOCK2(pwallet->cs_wallet, spk_man.cs_KeyStore); @@ -714,11 +698,6 @@ UniValue dumpprivkey(const JSONRPCRequest& request) UniValue dumpwallet(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const pwallet = GetWalletForJSONRPCRequest(request); - if (!EnsureWalletIsAvailable(pwallet.get(), request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"dumpwallet", "\nDumps all wallet keys in a human-readable format to a server-side file. This does not allow overwriting existing files.\n" "Imported scripts are included in the dumpfile, but corresponding BIP173 addresses, etc. may not be added automatically by importwallet.\n" @@ -739,6 +718,9 @@ UniValue dumpwallet(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const pwallet = GetWalletForJSONRPCRequest(request); + if (!pwallet) return NullUniValue; + CWallet& wallet = *pwallet; LegacyScriptPubKeyMan& spk_man = EnsureLegacyScriptPubKeyMan(wallet); @@ -835,7 +817,7 @@ UniValue dumpwallet(const JSONRPCRequest& request) create_time = FormatISO8601DateTime(it->second.nCreateTime); } if(spk_man.GetCScript(scriptid, script)) { - file << strprintf("%s %s script=1", HexStr(script.begin(), script.end()), create_time); + file << strprintf("%s %s script=1", HexStr(script), create_time); file << strprintf(" # addr=%s\n", address); } } @@ -874,20 +856,20 @@ static std::string RecurseImportData(const CScript& script, ImportData& import_d { // Use Solver to obtain script type and parsed pubkeys or hashes: std::vector<std::vector<unsigned char>> solverdata; - txnouttype script_type = Solver(script, solverdata); + TxoutType script_type = Solver(script, solverdata); switch (script_type) { - case TX_PUBKEY: { + case TxoutType::PUBKEY: { CPubKey pubkey(solverdata[0].begin(), solverdata[0].end()); import_data.used_keys.emplace(pubkey.GetID(), false); return ""; } - case TX_PUBKEYHASH: { + case TxoutType::PUBKEYHASH: { CKeyID id = CKeyID(uint160(solverdata[0])); import_data.used_keys[id] = true; return ""; } - case TX_SCRIPTHASH: { + case TxoutType::SCRIPTHASH: { if (script_ctx == ScriptContext::P2SH) throw JSONRPCError(RPC_INVALID_ADDRESS_OR_KEY, "Trying to nest P2SH inside another P2SH"); if (script_ctx == ScriptContext::WITNESS_V0) throw JSONRPCError(RPC_INVALID_ADDRESS_OR_KEY, "Trying to nest P2SH inside a P2WSH"); CHECK_NONFATAL(script_ctx == ScriptContext::TOP); @@ -898,14 +880,14 @@ static std::string RecurseImportData(const CScript& script, ImportData& import_d import_data.import_scripts.emplace(*subscript); return RecurseImportData(*subscript, import_data, ScriptContext::P2SH); } - case TX_MULTISIG: { + case TxoutType::MULTISIG: { for (size_t i = 1; i + 1< solverdata.size(); ++i) { CPubKey pubkey(solverdata[i].begin(), solverdata[i].end()); import_data.used_keys.emplace(pubkey.GetID(), false); } return ""; } - case TX_WITNESS_V0_SCRIPTHASH: { + case TxoutType::WITNESS_V0_SCRIPTHASH: { if (script_ctx == ScriptContext::WITNESS_V0) throw JSONRPCError(RPC_INVALID_ADDRESS_OR_KEY, "Trying to nest P2WSH inside another P2WSH"); uint256 fullid(solverdata[0]); CScriptID id; @@ -919,7 +901,7 @@ static std::string RecurseImportData(const CScript& script, ImportData& import_d import_data.import_scripts.emplace(*subscript); return RecurseImportData(*subscript, import_data, ScriptContext::WITNESS_V0); } - case TX_WITNESS_V0_KEYHASH: { + case TxoutType::WITNESS_V0_KEYHASH: { if (script_ctx == ScriptContext::WITNESS_V0) throw JSONRPCError(RPC_INVALID_ADDRESS_OR_KEY, "Trying to nest P2WPKH inside P2WSH"); CKeyID id = CKeyID(uint160(solverdata[0])); import_data.used_keys[id] = true; @@ -928,10 +910,10 @@ static std::string RecurseImportData(const CScript& script, ImportData& import_d } return ""; } - case TX_NULL_DATA: + case TxoutType::NULL_DATA: return "unspendable script"; - case TX_NONSTANDARD: - case TX_WITNESS_UNKNOWN: + case TxoutType::NONSTANDARD: + case TxoutType::WITNESS_UNKNOWN: default: return "unrecognized script"; } @@ -1211,7 +1193,7 @@ static UniValue ProcessImport(CWallet * const pwallet, const UniValue& data, con // Check whether we have any work to do for (const CScript& script : script_pub_keys) { if (pwallet->IsMine(script) & ISMINE_SPENDABLE) { - throw JSONRPCError(RPC_WALLET_ERROR, "The wallet already contains the private key for this address or script (\"" + HexStr(script.begin(), script.end()) + "\")"); + throw JSONRPCError(RPC_WALLET_ERROR, "The wallet already contains the private key for this address or script (\"" + HexStr(script) + "\")"); } } @@ -1259,12 +1241,6 @@ static int64_t GetImportTimestamp(const UniValue& data, int64_t now) UniValue importmulti(const JSONRPCRequest& mainRequest) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(mainRequest); - CWallet* const pwallet = wallet.get(); - if (!EnsureWalletIsAvailable(pwallet, mainRequest.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"importmulti", "\nImport addresses/scripts (with private or public keys, redeem script (P2SH)), optionally rescanning the blockchain from the earliest creation time of the imported scripts. Requires a new wallet backup.\n" "If an address/script is imported without all of the private keys required to spend from that address, it will be watchonly. The 'watchonly' option must be set to true in this case or a warning will be returned.\n" @@ -1340,6 +1316,9 @@ UniValue importmulti(const JSONRPCRequest& mainRequest) }, }.Check(mainRequest); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(mainRequest); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); RPCTypeCheck(mainRequest.params, {UniValue::VARR, UniValue::VOBJ}); @@ -1568,7 +1547,7 @@ static UniValue ProcessDescriptorImport(CWallet * const pwallet, const UniValue& if (!w_desc.descriptor->GetOutputType()) { warnings.push_back("Unknown output type, cannot set descriptor to active."); } else { - pwallet->SetActiveScriptPubKeyMan(spk_manager->GetID(), *w_desc.descriptor->GetOutputType(), internal); + pwallet->AddActiveScriptPubKeyMan(spk_manager->GetID(), *w_desc.descriptor->GetOutputType(), internal); } } @@ -1585,14 +1564,8 @@ static UniValue ProcessDescriptorImport(CWallet * const pwallet, const UniValue& return result; } -UniValue importdescriptors(const JSONRPCRequest& main_request) { - // Acquire the wallet - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(main_request); - CWallet* const pwallet = wallet.get(); - if (!EnsureWalletIsAvailable(pwallet, main_request.fHelp)) { - return NullUniValue; - } - +UniValue importdescriptors(const JSONRPCRequest& main_request) +{ RPCHelpMan{"importdescriptors", "\nImport descriptors. This will trigger a rescan of the blockchain based on the earliest timestamp of all descriptors being imported. Requires a new wallet backup.\n" "\nNote: This call can take over an hour to complete if using an early timestamp; during that time, other rpc calls\n" @@ -1644,6 +1617,10 @@ UniValue importdescriptors(const JSONRPCRequest& main_request) { }, }.Check(main_request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(main_request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + // Make sure wallet is a descriptor wallet if (!pwallet->IsWalletFlagSet(WALLET_FLAG_DESCRIPTORS)) { throw JSONRPCError(RPC_WALLET_ERROR, "importdescriptors is not available for non-descriptor wallets"); diff --git a/src/wallet/rpcwallet.cpp b/src/wallet/rpcwallet.cpp index 2a9ac189ea..7956533766 100644 --- a/src/wallet/rpcwallet.cpp +++ b/src/wallet/rpcwallet.cpp @@ -21,12 +21,14 @@ #include <util/fees.h> #include <util/message.h> // For MessageSign() #include <util/moneystr.h> +#include <util/ref.h> #include <util/string.h> #include <util/system.h> #include <util/translation.h> #include <util/url.h> #include <util/vector.h> #include <wallet/coincontrol.h> +#include <wallet/context.h> #include <wallet/feebumper.h> #include <wallet/rpcwallet.h> #include <wallet/wallet.h> @@ -43,6 +45,8 @@ using interfaces::FoundBlock; static const std::string WALLET_ENDPOINT_BASE = "/wallet/"; static const std::string HELP_REQUIRING_PASSPHRASE{"\nRequires wallet passphrase to be set with walletpassphrase call if wallet is encrypted.\n"}; +static const uint32_t WALLET_BTC_KB_TO_SAT_B = COIN / 1000; // 1 sat / B = 0.00001 BTC / kB + static inline bool GetAvoidReuseFlag(const CWallet* const pwallet, const UniValue& param) { bool can_avoid_reuse = pwallet->IsWalletFlagSet(WALLET_FLAG_AVOID_REUSE); bool avoid_reuse = param.isNull() ? can_avoid_reuse : param.get_bool(); @@ -91,6 +95,7 @@ bool GetWalletNameFromJSONRPCRequest(const JSONRPCRequest& request, std::string& std::shared_ptr<CWallet> GetWalletForJSONRPCRequest(const JSONRPCRequest& request) { + CHECK_NONFATAL(!request.fHelp); std::string wallet_name; if (GetWalletNameFromJSONRPCRequest(request, wallet_name)) { std::shared_ptr<CWallet> pwallet = GetWallet(wallet_name); @@ -99,14 +104,11 @@ std::shared_ptr<CWallet> GetWalletForJSONRPCRequest(const JSONRPCRequest& reques } std::vector<std::shared_ptr<CWallet>> wallets = GetWallets(); - return wallets.size() == 1 || (request.fHelp && wallets.size() > 0) ? wallets[0] : nullptr; -} + if (wallets.size() == 1) { + return wallets[0]; + } -bool EnsureWalletIsAvailable(const CWallet* pwallet, bool avoidException) -{ - if (pwallet) return true; - if (avoidException) return false; - if (!HasWallets()) { + if (wallets.empty()) { throw JSONRPCError( RPC_METHOD_NOT_FOUND, "Method not found (wallet method is disabled because no wallet is loaded)"); } @@ -121,6 +123,14 @@ void EnsureWalletIsUnlocked(const CWallet* pwallet) } } +WalletContext& EnsureWalletContext(const util::Ref& context) +{ + if (!context.Has<WalletContext>()) { + throw JSONRPCError(RPC_INTERNAL_ERROR, "Wallet context not found"); + } + return context.Get<WalletContext>(); +} + // also_create should only be set to true only when the RPC is expected to add things to a blank wallet and make it no longer blank LegacyScriptPubKeyMan& EnsureLegacyScriptPubKeyMan(CWallet& wallet, bool also_create) { @@ -183,15 +193,44 @@ static std::string LabelFromValue(const UniValue& value) return label; } -static UniValue getnewaddress(const JSONRPCRequest& request) +/** + * Update coin control with fee estimation based on the given parameters + * + * @param[in] pwallet Wallet pointer + * @param[in,out] cc Coin control which is to be updated + * @param[in] estimate_mode String value (e.g. "ECONOMICAL") + * @param[in] estimate_param Parameter (blocks to confirm, explicit fee rate, etc) + * @throws a JSONRPCError if estimate_mode is unknown, or if estimate_param is missing when required + */ +static void SetFeeEstimateMode(const CWallet* pwallet, CCoinControl& cc, const UniValue& estimate_mode, const UniValue& estimate_param) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); + if (!estimate_mode.isNull()) { + if (!FeeModeFromString(estimate_mode.get_str(), cc.m_fee_mode)) { + throw JSONRPCError(RPC_INVALID_PARAMETER, "Invalid estimate_mode parameter"); + } + } + + if (cc.m_fee_mode == FeeEstimateMode::BTC_KB || cc.m_fee_mode == FeeEstimateMode::SAT_B) { + if (estimate_param.isNull()) { + throw JSONRPCError(RPC_INVALID_PARAMETER, "Selected estimate_mode requires a fee rate"); + } + + CAmount fee_rate = AmountFromValue(estimate_param); + if (cc.m_fee_mode == FeeEstimateMode::SAT_B) { + fee_rate /= WALLET_BTC_KB_TO_SAT_B; + } + + cc.m_feerate = CFeeRate(fee_rate); - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; + // default RBF to true for explicit fee rate modes + if (cc.m_signal_bip125_rbf == nullopt) cc.m_signal_bip125_rbf = true; + } else if (!estimate_param.isNull()) { + cc.m_confirm_target = ParseConfirmTarget(estimate_param, pwallet->chain().estimateMaxBlocks()); } +} +static UniValue getnewaddress(const JSONRPCRequest& request) +{ RPCHelpMan{"getnewaddress", "\nReturns a new Bitcoin address for receiving payments.\n" "If 'label' is specified, it is added to the address book \n" @@ -209,6 +248,10 @@ static UniValue getnewaddress(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + LOCK(pwallet->cs_wallet); if (!pwallet->CanGetAddresses()) { @@ -238,13 +281,6 @@ static UniValue getnewaddress(const JSONRPCRequest& request) static UniValue getrawchangeaddress(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"getrawchangeaddress", "\nReturns a new Bitcoin address, for receiving change.\n" "This is for use with raw transactions, NOT normal use.\n", @@ -260,13 +296,17 @@ static UniValue getrawchangeaddress(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + LOCK(pwallet->cs_wallet); if (!pwallet->CanGetAddresses(true)) { throw JSONRPCError(RPC_WALLET_ERROR, "Error: This wallet has no available keys"); } - OutputType output_type = pwallet->m_default_change_type != OutputType::CHANGE_AUTO ? pwallet->m_default_change_type : pwallet->m_default_address_type; + OutputType output_type = pwallet->m_default_change_type.get_value_or(pwallet->m_default_address_type); if (!request.params[0].isNull()) { if (!ParseOutputType(request.params[0].get_str(), output_type)) { throw JSONRPCError(RPC_INVALID_ADDRESS_OR_KEY, strprintf("Unknown address type '%s'", request.params[0].get_str())); @@ -284,13 +324,6 @@ static UniValue getrawchangeaddress(const JSONRPCRequest& request) static UniValue setlabel(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"setlabel", "\nSets the label associated with the given address.\n", { @@ -304,6 +337,10 @@ static UniValue setlabel(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + LOCK(pwallet->cs_wallet); CTxDestination dest = DecodeDestination(request.params[0].get_str()); @@ -322,47 +359,58 @@ static UniValue setlabel(const JSONRPCRequest& request) return NullUniValue; } +void ParseRecipients(const UniValue& address_amounts, const UniValue& subtract_fee_outputs, std::vector<CRecipient> &recipients) { + std::set<CTxDestination> destinations; + int i = 0; + for (const std::string& address: address_amounts.getKeys()) { + CTxDestination dest = DecodeDestination(address); + if (!IsValidDestination(dest)) { + throw JSONRPCError(RPC_INVALID_ADDRESS_OR_KEY, std::string("Invalid Bitcoin address: ") + address); + } -static CTransactionRef SendMoney(CWallet* const pwallet, const CTxDestination& address, CAmount nValue, bool fSubtractFeeFromAmount, const CCoinControl& coin_control, mapValue_t mapValue) -{ - CAmount curBalance = pwallet->GetBalance(0, coin_control.m_avoid_address_reuse).m_mine_trusted; + if (destinations.count(dest)) { + throw JSONRPCError(RPC_INVALID_PARAMETER, std::string("Invalid parameter, duplicated address: ") + address); + } + destinations.insert(dest); - // Check amount - if (nValue <= 0) - throw JSONRPCError(RPC_INVALID_PARAMETER, "Invalid amount"); + CScript script_pub_key = GetScriptForDestination(dest); + CAmount amount = AmountFromValue(address_amounts[i++]); - if (nValue > curBalance) - throw JSONRPCError(RPC_WALLET_INSUFFICIENT_FUNDS, "Insufficient funds"); + bool subtract_fee = false; + for (unsigned int idx = 0; idx < subtract_fee_outputs.size(); idx++) { + const UniValue& addr = subtract_fee_outputs[idx]; + if (addr.get_str() == address) { + subtract_fee = true; + } + } - // Parse Bitcoin address - CScript scriptPubKey = GetScriptForDestination(address); + CRecipient recipient = {script_pub_key, amount, subtract_fee}; + recipients.push_back(recipient); + } +} + +UniValue SendMoney(CWallet* const pwallet, const CCoinControl &coin_control, std::vector<CRecipient> &recipients, mapValue_t map_value) +{ + EnsureWalletIsUnlocked(pwallet); - // Create and send the transaction + // Shuffle recipient list + std::shuffle(recipients.begin(), recipients.end(), FastRandomContext()); + + // Send CAmount nFeeRequired = 0; - bilingual_str error; - std::vector<CRecipient> vecSend; int nChangePosRet = -1; - CRecipient recipient = {scriptPubKey, nValue, fSubtractFeeFromAmount}; - vecSend.push_back(recipient); + bilingual_str error; CTransactionRef tx; - if (!pwallet->CreateTransaction(vecSend, tx, nFeeRequired, nChangePosRet, error, coin_control)) { - if (!fSubtractFeeFromAmount && nValue + nFeeRequired > curBalance) - error = strprintf(Untranslated("Error: This transaction requires a transaction fee of at least %s"), FormatMoney(nFeeRequired)); - throw JSONRPCError(RPC_WALLET_ERROR, error.original); + bool fCreated = pwallet->CreateTransaction(recipients, tx, nFeeRequired, nChangePosRet, error, coin_control, !pwallet->IsWalletFlagSet(WALLET_FLAG_DISABLE_PRIVATE_KEYS)); + if (!fCreated) { + throw JSONRPCError(RPC_WALLET_INSUFFICIENT_FUNDS, error.original); } - pwallet->CommitTransaction(tx, std::move(mapValue), {} /* orderForm */); - return tx; + pwallet->CommitTransaction(tx, std::move(map_value), {} /* orderForm */); + return tx->GetHash().GetHex(); } static UniValue sendtoaddress(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"sendtoaddress", "\nSend an amount to a given address." + HELP_REQUIRING_PASSPHRASE, @@ -377,11 +425,9 @@ static UniValue sendtoaddress(const JSONRPCRequest& request) {"subtractfeefromamount", RPCArg::Type::BOOL, /* default */ "false", "The fee will be deducted from the amount being sent.\n" " The recipient will receive less bitcoins than you enter in the amount field."}, {"replaceable", RPCArg::Type::BOOL, /* default */ "wallet default", "Allow this transaction to be replaced by a transaction with higher fees via BIP 125"}, - {"conf_target", RPCArg::Type::NUM, /* default */ "wallet default", "Confirmation target (in blocks)"}, - {"estimate_mode", RPCArg::Type::STR, /* default */ "UNSET", "The fee estimate mode, must be one of:\n" - " \"UNSET\"\n" - " \"ECONOMICAL\"\n" - " \"CONSERVATIVE\""}, + {"conf_target", RPCArg::Type::NUM, /* default */ "wallet default", "Confirmation target (in blocks), or fee rate (for " + CURRENCY_UNIT + "/kB or " + CURRENCY_ATOM + "/B estimate modes)"}, + {"estimate_mode", RPCArg::Type::STR, /* default */ "unset", std::string() + "The fee estimate mode, must be one of (case insensitive):\n" + " \"" + FeeModes("\"\n\"") + "\""}, {"avoid_reuse", RPCArg::Type::BOOL, /* default */ "true", "(only available if avoid_reuse wallet flag is set) Avoid spending from dirty addresses; addresses are considered\n" " dirty if they have previously been used in a transaction."}, }, @@ -392,26 +438,22 @@ static UniValue sendtoaddress(const JSONRPCRequest& request) HelpExampleCli("sendtoaddress", "\"" + EXAMPLE_ADDRESS[0] + "\" 0.1") + HelpExampleCli("sendtoaddress", "\"" + EXAMPLE_ADDRESS[0] + "\" 0.1 \"donation\" \"seans outpost\"") + HelpExampleCli("sendtoaddress", "\"" + EXAMPLE_ADDRESS[0] + "\" 0.1 \"\" \"\" true") + + HelpExampleCli("sendtoaddress", "\"" + EXAMPLE_ADDRESS[0] + "\" 0.1 \"\" \"\" false true 0.00002 " + (CURRENCY_UNIT + "/kB")) + + HelpExampleCli("sendtoaddress", "\"" + EXAMPLE_ADDRESS[0] + "\" 0.1 \"\" \"\" false true 2 " + (CURRENCY_ATOM + "/B")) + HelpExampleRpc("sendtoaddress", "\"" + EXAMPLE_ADDRESS[0] + "\", 0.1, \"donation\", \"seans outpost\"") }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + // Make sure the results are valid at least up to the most recent block // the user could have gotten from another RPC command prior to now pwallet->BlockUntilSyncedToCurrentChain(); LOCK(pwallet->cs_wallet); - CTxDestination dest = DecodeDestination(request.params[0].get_str()); - if (!IsValidDestination(dest)) { - throw JSONRPCError(RPC_INVALID_ADDRESS_OR_KEY, "Invalid address"); - } - - // Amount - CAmount nAmount = AmountFromValue(request.params[1]); - if (nAmount <= 0) - throw JSONRPCError(RPC_TYPE_ERROR, "Invalid amount for send"); - // Wallet comments mapValue_t mapValue; if (!request.params[2].isNull() && !request.params[2].get_str().empty()) @@ -429,35 +471,30 @@ static UniValue sendtoaddress(const JSONRPCRequest& request) coin_control.m_signal_bip125_rbf = request.params[5].get_bool(); } - if (!request.params[6].isNull()) { - coin_control.m_confirm_target = ParseConfirmTarget(request.params[6], pwallet->chain().estimateMaxBlocks()); - } - - if (!request.params[7].isNull()) { - if (!FeeModeFromString(request.params[7].get_str(), coin_control.m_fee_mode)) { - throw JSONRPCError(RPC_INVALID_PARAMETER, "Invalid estimate_mode parameter"); - } - } - coin_control.m_avoid_address_reuse = GetAvoidReuseFlag(pwallet, request.params[8]); // We also enable partial spend avoidance if reuse avoidance is set. coin_control.m_avoid_partial_spends |= coin_control.m_avoid_address_reuse; + SetFeeEstimateMode(pwallet, coin_control, request.params[7], request.params[6]); + EnsureWalletIsUnlocked(pwallet); - CTransactionRef tx = SendMoney(pwallet, dest, nAmount, fSubtractFeeFromAmount, coin_control, std::move(mapValue)); - return tx->GetHash().GetHex(); + UniValue address_amounts(UniValue::VOBJ); + const std::string address = request.params[0].get_str(); + address_amounts.pushKV(address, request.params[1]); + UniValue subtractFeeFromAmount(UniValue::VARR); + if (fSubtractFeeFromAmount) { + subtractFeeFromAmount.push_back(address); + } + + std::vector<CRecipient> recipients; + ParseRecipients(address_amounts, subtractFeeFromAmount, recipients); + + return SendMoney(pwallet, coin_control, recipients, mapValue); } static UniValue listaddressgroupings(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - const CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"listaddressgroupings", "\nLists groups of addresses which have had their common ownership\n" "made public by common use as inputs or as the resulting change\n" @@ -483,6 +520,10 @@ static UniValue listaddressgroupings(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + const CWallet* const pwallet = wallet.get(); + // Make sure the results are valid at least up to the most recent block // the user could have gotten from another RPC command prior to now pwallet->BlockUntilSyncedToCurrentChain(); @@ -513,13 +554,6 @@ static UniValue listaddressgroupings(const JSONRPCRequest& request) static UniValue signmessage(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - const CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"signmessage", "\nSign a message with the private key of an address" + HELP_REQUIRING_PASSPHRASE, @@ -542,6 +576,10 @@ static UniValue signmessage(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + const CWallet* const pwallet = wallet.get(); + LOCK(pwallet->cs_wallet); EnsureWalletIsUnlocked(pwallet); @@ -618,13 +656,6 @@ static CAmount GetReceived(const CWallet& wallet, const UniValue& params, bool b static UniValue getreceivedbyaddress(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - const CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"getreceivedbyaddress", "\nReturns the total amount received by the given address in transactions with at least minconf confirmations.\n", { @@ -646,6 +677,10 @@ static UniValue getreceivedbyaddress(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + const CWallet* const pwallet = wallet.get(); + // Make sure the results are valid at least up to the most recent block // the user could have gotten from another RPC command prior to now pwallet->BlockUntilSyncedToCurrentChain(); @@ -658,13 +693,6 @@ static UniValue getreceivedbyaddress(const JSONRPCRequest& request) static UniValue getreceivedbylabel(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - const CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"getreceivedbylabel", "\nReturns the total amount received by addresses with <label> in transactions with at least [minconf] confirmations.\n", { @@ -686,6 +714,10 @@ static UniValue getreceivedbylabel(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + const CWallet* const pwallet = wallet.get(); + // Make sure the results are valid at least up to the most recent block // the user could have gotten from another RPC command prior to now pwallet->BlockUntilSyncedToCurrentChain(); @@ -698,13 +730,6 @@ static UniValue getreceivedbylabel(const JSONRPCRequest& request) static UniValue getbalance(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - const CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"getbalance", "\nReturns the total available balance.\n" "The available balance is what the wallet considers currently spendable, and is\n" @@ -728,6 +753,10 @@ static UniValue getbalance(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + const CWallet* const pwallet = wallet.get(); + // Make sure the results are valid at least up to the most recent block // the user could have gotten from another RPC command prior to now pwallet->BlockUntilSyncedToCurrentChain(); @@ -755,13 +784,6 @@ static UniValue getbalance(const JSONRPCRequest& request) static UniValue getunconfirmedbalance(const JSONRPCRequest &request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - const CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"getunconfirmedbalance", "DEPRECATED\nIdentical to getbalances().mine.untrusted_pending\n", {}, @@ -769,6 +791,10 @@ static UniValue getunconfirmedbalance(const JSONRPCRequest &request) RPCExamples{""}, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + const CWallet* const pwallet = wallet.get(); + // Make sure the results are valid at least up to the most recent block // the user could have gotten from another RPC command prior to now pwallet->BlockUntilSyncedToCurrentChain(); @@ -781,13 +807,6 @@ static UniValue getunconfirmedbalance(const JSONRPCRequest &request) static UniValue sendmany(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"sendmany", "\nSend multiple times. Amounts are double-precision floating point numbers." + HELP_REQUIRING_PASSPHRASE, @@ -809,11 +828,9 @@ static UniValue sendmany(const JSONRPCRequest& request) }, }, {"replaceable", RPCArg::Type::BOOL, /* default */ "wallet default", "Allow this transaction to be replaced by a transaction with higher fees via BIP 125"}, - {"conf_target", RPCArg::Type::NUM, /* default */ "wallet default", "Confirmation target (in blocks)"}, - {"estimate_mode", RPCArg::Type::STR, /* default */ "UNSET", "The fee estimate mode, must be one of:\n" - " \"UNSET\"\n" - " \"ECONOMICAL\"\n" - " \"CONSERVATIVE\""}, + {"conf_target", RPCArg::Type::NUM, /* default */ "wallet default", "Confirmation target (in blocks), or fee rate (for " + CURRENCY_UNIT + "/kB or " + CURRENCY_ATOM + "/B estimate modes)"}, + {"estimate_mode", RPCArg::Type::STR, /* default */ "unset", std::string() + "The fee estimate mode, must be one of (case insensitive):\n" + " \"" + FeeModes("\"\n\"") + "\""}, }, RPCResult{ RPCResult::Type::STR_HEX, "txid", "The transaction id for the send. Only 1 transaction is created regardless of\n" @@ -831,6 +848,10 @@ static UniValue sendmany(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + // Make sure the results are valid at least up to the most recent block // the user could have gotten from another RPC command prior to now pwallet->BlockUntilSyncedToCurrentChain(); @@ -855,73 +876,16 @@ static UniValue sendmany(const JSONRPCRequest& request) coin_control.m_signal_bip125_rbf = request.params[5].get_bool(); } - if (!request.params[6].isNull()) { - coin_control.m_confirm_target = ParseConfirmTarget(request.params[6], pwallet->chain().estimateMaxBlocks()); - } + SetFeeEstimateMode(pwallet, coin_control, request.params[7], request.params[6]); - if (!request.params[7].isNull()) { - if (!FeeModeFromString(request.params[7].get_str(), coin_control.m_fee_mode)) { - throw JSONRPCError(RPC_INVALID_PARAMETER, "Invalid estimate_mode parameter"); - } - } - - std::set<CTxDestination> destinations; - std::vector<CRecipient> vecSend; - - std::vector<std::string> keys = sendTo.getKeys(); - for (const std::string& name_ : keys) { - CTxDestination dest = DecodeDestination(name_); - if (!IsValidDestination(dest)) { - throw JSONRPCError(RPC_INVALID_ADDRESS_OR_KEY, std::string("Invalid Bitcoin address: ") + name_); - } - - if (destinations.count(dest)) { - throw JSONRPCError(RPC_INVALID_PARAMETER, std::string("Invalid parameter, duplicated address: ") + name_); - } - destinations.insert(dest); - - CScript scriptPubKey = GetScriptForDestination(dest); - CAmount nAmount = AmountFromValue(sendTo[name_]); - if (nAmount <= 0) - throw JSONRPCError(RPC_TYPE_ERROR, "Invalid amount for send"); - - bool fSubtractFeeFromAmount = false; - for (unsigned int idx = 0; idx < subtractFeeFromAmount.size(); idx++) { - const UniValue& addr = subtractFeeFromAmount[idx]; - if (addr.get_str() == name_) - fSubtractFeeFromAmount = true; - } - - CRecipient recipient = {scriptPubKey, nAmount, fSubtractFeeFromAmount}; - vecSend.push_back(recipient); - } + std::vector<CRecipient> recipients; + ParseRecipients(sendTo, subtractFeeFromAmount, recipients); - EnsureWalletIsUnlocked(pwallet); - - // Shuffle recipient list - std::shuffle(vecSend.begin(), vecSend.end(), FastRandomContext()); - - // Send - CAmount nFeeRequired = 0; - int nChangePosRet = -1; - bilingual_str error; - CTransactionRef tx; - bool fCreated = pwallet->CreateTransaction(vecSend, tx, nFeeRequired, nChangePosRet, error, coin_control); - if (!fCreated) - throw JSONRPCError(RPC_WALLET_INSUFFICIENT_FUNDS, error.original); - pwallet->CommitTransaction(tx, std::move(mapValue), {} /* orderForm */); - return tx->GetHash().GetHex(); + return SendMoney(pwallet, coin_control, recipients, std::move(mapValue)); } static UniValue addmultisigaddress(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"addmultisigaddress", "\nAdd an nrequired-to-sign multisignature address to the wallet. Requires a new wallet backup.\n" "Each key is a Bitcoin address or hex-encoded public key.\n" @@ -954,6 +918,10 @@ static UniValue addmultisigaddress(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + LegacyScriptPubKeyMan& spk_man = EnsureLegacyScriptPubKeyMan(*pwallet); LOCK2(pwallet->cs_wallet, spk_man.cs_KeyStore); @@ -992,7 +960,7 @@ static UniValue addmultisigaddress(const JSONRPCRequest& request) UniValue result(UniValue::VOBJ); result.pushKV("address", EncodeDestination(dest)); - result.pushKV("redeemScript", HexStr(inner.begin(), inner.end())); + result.pushKV("redeemScript", HexStr(inner)); result.pushKV("descriptor", descriptor->ToString()); return result; } @@ -1157,13 +1125,6 @@ static UniValue ListReceived(const CWallet* const pwallet, const UniValue& param static UniValue listreceivedbyaddress(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - const CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"listreceivedbyaddress", "\nList balances by receiving address.\n", { @@ -1197,6 +1158,10 @@ static UniValue listreceivedbyaddress(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + const CWallet* const pwallet = wallet.get(); + // Make sure the results are valid at least up to the most recent block // the user could have gotten from another RPC command prior to now pwallet->BlockUntilSyncedToCurrentChain(); @@ -1208,13 +1173,6 @@ static UniValue listreceivedbyaddress(const JSONRPCRequest& request) static UniValue listreceivedbylabel(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - const CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"listreceivedbylabel", "\nList received transactions by label.\n", { @@ -1241,6 +1199,10 @@ static UniValue listreceivedbylabel(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + const CWallet* const pwallet = wallet.get(); + // Make sure the results are valid at least up to the most recent block // the user could have gotten from another RPC command prior to now pwallet->BlockUntilSyncedToCurrentChain(); @@ -1369,13 +1331,6 @@ static const std::vector<RPCResult> TransactionDescriptionString() UniValue listtransactions(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - const CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"listtransactions", "\nIf a label name is provided, this will return only incoming transactions paying to addresses with the specified label.\n" "\nReturns up to 'count' most recent transactions skipping the first 'from' transactions.\n", @@ -1423,6 +1378,10 @@ UniValue listtransactions(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + const CWallet* const pwallet = wallet.get(); + // Make sure the results are valid at least up to the most recent block // the user could have gotten from another RPC command prior to now pwallet->BlockUntilSyncedToCurrentChain(); @@ -1482,12 +1441,6 @@ UniValue listtransactions(const JSONRPCRequest& request) static UniValue listsinceblock(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const pwallet = GetWalletForJSONRPCRequest(request); - - if (!EnsureWalletIsAvailable(pwallet.get(), request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"listsinceblock", "\nGet all transactions in blocks since block [blockhash], or all transactions if omitted.\n" "If \"blockhash\" is no longer a part of the main chain, transactions from the fork point onward are included.\n" @@ -1531,7 +1484,7 @@ static UniValue listsinceblock(const JSONRPCRequest& request) {RPCResult::Type::ARR, "removed", "<structure is the same as \"transactions\" above, only present if include_removed=true>\n" "Note: transactions that were re-added in the active chain will appear as-is in this array, and may thus have a positive confirmation count." , {{RPCResult::Type::ELISION, "", ""},}}, - {RPCResult::Type::STR_HEX, "lastblock", "The hash of the block (target_confirmations-1) from the best block on the main chain. This is typically used to feed back into listsinceblock the next time you call it. So you would generally use a target_confirmations of say 6, so you will be continually re-notified of transactions until they've reached 6 confirmations plus any new ones"}, + {RPCResult::Type::STR_HEX, "lastblock", "The hash of the block (target_confirmations-1) from the best block on the main chain, or the genesis hash if the referenced block does not exist yet. This is typically used to feed back into listsinceblock the next time you call it. So you would generally use a target_confirmations of say 6, so you will be continually re-notified of transactions until they've reached 6 confirmations plus any new ones"}, } }, RPCExamples{ @@ -1541,6 +1494,9 @@ static UniValue listsinceblock(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const pwallet = GetWalletForJSONRPCRequest(request); + if (!pwallet) return NullUniValue; + const CWallet& wallet = *pwallet; // Make sure the results are valid at least up to the most recent block // the user could have gotten from another RPC command prior to now @@ -1611,6 +1567,7 @@ static UniValue listsinceblock(const JSONRPCRequest& request) } uint256 lastblock; + target_confirms = std::min(target_confirms, wallet.GetLastBlockHeight() + 1); CHECK_NONFATAL(wallet.chain().findAncestorByHeight(wallet.GetLastBlockHash(), wallet.GetLastBlockHeight() + 1 - target_confirms, FoundBlock().hash(lastblock))); UniValue ret(UniValue::VOBJ); @@ -1623,13 +1580,6 @@ static UniValue listsinceblock(const JSONRPCRequest& request) static UniValue gettransaction(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - const CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"gettransaction", "\nGet detailed information about in-wallet transaction <txid>\n", { @@ -1684,6 +1634,10 @@ static UniValue gettransaction(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + const CWallet* const pwallet = wallet.get(); + // Make sure the results are valid at least up to the most recent block // the user could have gotten from another RPC command prior to now pwallet->BlockUntilSyncedToCurrentChain(); @@ -1736,13 +1690,6 @@ static UniValue gettransaction(const JSONRPCRequest& request) static UniValue abandontransaction(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"abandontransaction", "\nMark in-wallet transaction <txid> as abandoned\n" "This will mark this transaction and all its in-wallet descendants as abandoned which will allow\n" @@ -1759,6 +1706,10 @@ static UniValue abandontransaction(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + // Make sure the results are valid at least up to the most recent block // the user could have gotten from another RPC command prior to now pwallet->BlockUntilSyncedToCurrentChain(); @@ -1780,13 +1731,6 @@ static UniValue abandontransaction(const JSONRPCRequest& request) static UniValue backupwallet(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - const CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"backupwallet", "\nSafely copies current wallet file to destination, which can be a directory or a path with filename.\n", { @@ -1799,6 +1743,10 @@ static UniValue backupwallet(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + const CWallet* const pwallet = wallet.get(); + // Make sure the results are valid at least up to the most recent block // the user could have gotten from another RPC command prior to now pwallet->BlockUntilSyncedToCurrentChain(); @@ -1816,13 +1764,6 @@ static UniValue backupwallet(const JSONRPCRequest& request) static UniValue keypoolrefill(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"keypoolrefill", "\nFills the keypool."+ HELP_REQUIRING_PASSPHRASE, @@ -1836,6 +1777,10 @@ static UniValue keypoolrefill(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + if (pwallet->IsLegacy() && pwallet->IsWalletFlagSet(WALLET_FLAG_DISABLE_PRIVATE_KEYS)) { throw JSONRPCError(RPC_WALLET_ERROR, "Error: Private keys are disabled for this wallet"); } @@ -1863,13 +1808,6 @@ static UniValue keypoolrefill(const JSONRPCRequest& request) static UniValue walletpassphrase(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"walletpassphrase", "\nStores the wallet decryption key in memory for 'timeout' seconds.\n" "This is needed prior to performing transactions related to private keys such as sending bitcoins\n" @@ -1891,6 +1829,10 @@ static UniValue walletpassphrase(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + int64_t nSleepTime; int64_t relock_time; // Prevent concurrent calls to walletpassphrase with the same wallet. @@ -1960,13 +1902,6 @@ static UniValue walletpassphrase(const JSONRPCRequest& request) static UniValue walletpassphrasechange(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"walletpassphrasechange", "\nChanges the wallet passphrase from 'oldpassphrase' to 'newpassphrase'.\n", { @@ -1980,6 +1915,10 @@ static UniValue walletpassphrasechange(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + LOCK(pwallet->cs_wallet); if (!pwallet->IsCrypted()) { @@ -2010,13 +1949,6 @@ static UniValue walletpassphrasechange(const JSONRPCRequest& request) static UniValue walletlock(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"walletlock", "\nRemoves the wallet encryption key from memory, locking the wallet.\n" "After calling this method, you will need to call walletpassphrase again\n" @@ -2035,6 +1967,10 @@ static UniValue walletlock(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + LOCK(pwallet->cs_wallet); if (!pwallet->IsCrypted()) { @@ -2050,13 +1986,6 @@ static UniValue walletlock(const JSONRPCRequest& request) static UniValue encryptwallet(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"encryptwallet", "\nEncrypts the wallet with 'passphrase'. This is for first time encryption.\n" "After this, any calls that interact with private keys such as sending or signing \n" @@ -2081,6 +2010,10 @@ static UniValue encryptwallet(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + LOCK(pwallet->cs_wallet); if (pwallet->IsWalletFlagSet(WALLET_FLAG_DISABLE_PRIVATE_KEYS)) { @@ -2110,13 +2043,6 @@ static UniValue encryptwallet(const JSONRPCRequest& request) static UniValue lockunspent(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"lockunspent", "\nUpdates list of temporarily unspendable outputs.\n" "Temporarily lock (unlock=false) or unlock (unlock=true) specified transaction outputs.\n" @@ -2155,6 +2081,10 @@ static UniValue lockunspent(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + // Make sure the results are valid at least up to the most recent block // the user could have gotten from another RPC command prior to now pwallet->BlockUntilSyncedToCurrentChain(); @@ -2236,13 +2166,6 @@ static UniValue lockunspent(const JSONRPCRequest& request) static UniValue listlockunspent(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - const CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"listlockunspent", "\nReturns list of temporarily unspendable outputs.\n" "See the lockunspent call to lock and unlock transactions for spending.\n", @@ -2271,6 +2194,10 @@ static UniValue listlockunspent(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + const CWallet* const pwallet = wallet.get(); + LOCK(pwallet->cs_wallet); std::vector<COutPoint> vOutpts; @@ -2291,13 +2218,6 @@ static UniValue listlockunspent(const JSONRPCRequest& request) static UniValue settxfee(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"settxfee", "\nSet the transaction fee per kB for this wallet. Overrides the global -paytxfee command line parameter.\n" "Can be deactivated by passing 0 as the fee. In that case automatic fee selection will be used by default.\n", @@ -2313,6 +2233,10 @@ static UniValue settxfee(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + LOCK(pwallet->cs_wallet); CAmount nAmount = AmountFromValue(request.params[0]); @@ -2334,12 +2258,6 @@ static UniValue settxfee(const JSONRPCRequest& request) static UniValue getbalances(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const rpc_wallet = GetWalletForJSONRPCRequest(request); - if (!EnsureWalletIsAvailable(rpc_wallet.get(), request.fHelp)) { - return NullUniValue; - } - CWallet& wallet = *rpc_wallet; - RPCHelpMan{ "getbalances", "Returns an object with all balances in " + CURRENCY_UNIT + ".\n", @@ -2367,6 +2285,10 @@ static UniValue getbalances(const JSONRPCRequest& request) HelpExampleRpc("getbalances", "")}, }.Check(request); + std::shared_ptr<CWallet> const rpc_wallet = GetWalletForJSONRPCRequest(request); + if (!rpc_wallet) return NullUniValue; + CWallet& wallet = *rpc_wallet; + // Make sure the results are valid at least up to the most recent block // the user could have gotten from another RPC command prior to now wallet.BlockUntilSyncedToCurrentChain(); @@ -2401,13 +2323,6 @@ static UniValue getbalances(const JSONRPCRequest& request) static UniValue getwalletinfo(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - const CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"getwalletinfo", "Returns an object containing various wallet state info.\n", {}, @@ -2424,7 +2339,7 @@ static UniValue getwalletinfo(const JSONRPCRequest& request) {RPCResult::Type::NUM_TIME, "keypoololdest", "the " + UNIX_EPOCH_TIME + " of the oldest pre-generated key in the key pool. Legacy wallets only."}, {RPCResult::Type::NUM, "keypoolsize", "how many new keys are pre-generated (only counts external keys)"}, {RPCResult::Type::NUM, "keypoolsize_hd_internal", "how many new keys are pre-generated for internal use (used for change outputs, only appears if the wallet is using this feature, otherwise external keys are used)"}, - {RPCResult::Type::NUM_TIME, "unlocked_until", "the " + UNIX_EPOCH_TIME + " until which the wallet is unlocked for transfers, or 0 if the wallet is locked"}, + {RPCResult::Type::NUM_TIME, "unlocked_until", /* optional */ true, "the " + UNIX_EPOCH_TIME + " until which the wallet is unlocked for transfers, or 0 if the wallet is locked (only present for passphrase-encrypted wallets)"}, {RPCResult::Type::STR_AMOUNT, "paytxfee", "the transaction fee configuration, set in " + CURRENCY_UNIT + "/kB"}, {RPCResult::Type::STR_HEX, "hdseedid", /* optional */ true, "the Hash160 of the HD seed (only present when HD is enabled)"}, {RPCResult::Type::BOOL, "private_keys_enabled", "false if privatekeys are disabled for this wallet (enforced watch-only wallet)"}, @@ -2443,6 +2358,10 @@ static UniValue getwalletinfo(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + const CWallet* const pwallet = wallet.get(); + // Make sure the results are valid at least up to the most recent block // the user could have gotten from another RPC command prior to now pwallet->BlockUntilSyncedToCurrentChain(); @@ -2550,12 +2469,7 @@ static UniValue listwallets(const JSONRPCRequest& request) UniValue obj(UniValue::VARR); for (const std::shared_ptr<CWallet>& wallet : GetWallets()) { - if (!EnsureWalletIsAvailable(wallet.get(), request.fHelp)) { - return NullUniValue; - } - LOCK(wallet->cs_wallet); - obj.push_back(wallet->GetName()); } @@ -2584,6 +2498,7 @@ static UniValue loadwallet(const JSONRPCRequest& request) }, }.Check(request); + WalletContext& context = EnsureWalletContext(request.context); WalletLocation location(request.params[0].get_str()); if (!location.Exists()) { @@ -2598,7 +2513,7 @@ static UniValue loadwallet(const JSONRPCRequest& request) bilingual_str error; std::vector<bilingual_str> warnings; - std::shared_ptr<CWallet> const wallet = LoadWallet(*g_rpc_chain, location, error, warnings); + std::shared_ptr<CWallet> const wallet = LoadWallet(*context.chain, location, error, warnings); if (!wallet) throw JSONRPCError(RPC_WALLET_ERROR, error.original); UniValue obj(UniValue::VOBJ); @@ -2610,13 +2525,6 @@ static UniValue loadwallet(const JSONRPCRequest& request) static UniValue setwalletflag(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - std::string flags = ""; for (auto& it : WALLET_FLAG_MAP) if (it.second & MUTABLE_WALLET_FLAGS) @@ -2641,6 +2549,10 @@ static UniValue setwalletflag(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + std::string flag_str = request.params[0].get_str(); bool value = request.params[1].isNull() || request.params[1].get_bool(); @@ -2702,6 +2614,7 @@ static UniValue createwallet(const JSONRPCRequest& request) }, }.Check(request); + WalletContext& context = EnsureWalletContext(request.context); uint64_t flags = 0; if (!request.params[1].isNull() && request.params[1].get_bool()) { flags |= WALLET_FLAG_DISABLE_PRIVATE_KEYS; @@ -2731,7 +2644,7 @@ static UniValue createwallet(const JSONRPCRequest& request) bilingual_str error; std::shared_ptr<CWallet> wallet; - WalletCreationStatus status = CreateWallet(*g_rpc_chain, passphrase, flags, request.params[0].get_str(), error, warnings, wallet); + WalletCreationStatus status = CreateWallet(*context.chain, passphrase, flags, request.params[0].get_str(), error, warnings, wallet); switch (status) { case WalletCreationStatus::CREATION_FAILED: throw JSONRPCError(RPC_WALLET_ERROR, error.original); @@ -2792,13 +2705,6 @@ static UniValue unloadwallet(const JSONRPCRequest& request) static UniValue listunspent(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - const CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{ "listunspent", "\nReturns array of unspent transaction outputs\n" @@ -2856,6 +2762,10 @@ static UniValue listunspent(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + const CWallet* const pwallet = wallet.get(); + int nMinDepth = 1; if (!request.params[0].isNull()) { RPCTypeCheckArgument(request.params[0], UniValue::VNUM); @@ -2957,7 +2867,7 @@ static UniValue listunspent(const JSONRPCRequest& request) const CScriptID& hash = CScriptID(boost::get<ScriptHash>(address)); CScript redeemScript; if (provider->GetCScript(hash, redeemScript)) { - entry.pushKV("redeemScript", HexStr(redeemScript.begin(), redeemScript.end())); + entry.pushKV("redeemScript", HexStr(redeemScript)); // Now check if the redeemScript is actually a P2WSH script CTxDestination witness_destination; if (redeemScript.IsPayToWitnessScriptHash()) { @@ -2969,7 +2879,7 @@ static UniValue listunspent(const JSONRPCRequest& request) CRIPEMD160().Write(whash.begin(), whash.size()).Finalize(id.begin()); CScript witnessScript; if (provider->GetCScript(id, witnessScript)) { - entry.pushKV("witnessScript", HexStr(witnessScript.begin(), witnessScript.end())); + entry.pushKV("witnessScript", HexStr(witnessScript)); } } } @@ -2979,13 +2889,13 @@ static UniValue listunspent(const JSONRPCRequest& request) CRIPEMD160().Write(whash.begin(), whash.size()).Finalize(id.begin()); CScript witnessScript; if (provider->GetCScript(id, witnessScript)) { - entry.pushKV("witnessScript", HexStr(witnessScript.begin(), witnessScript.end())); + entry.pushKV("witnessScript", HexStr(witnessScript)); } } } } - entry.pushKV("scriptPubKey", HexStr(scriptPubKey.begin(), scriptPubKey.end())); + entry.pushKV("scriptPubKey", HexStr(scriptPubKey)); entry.pushKV("amount", ValueFromAmount(out.tx->tx->vout[out.i].nValue)); entry.pushKV("confirmations", out.nDepth); entry.pushKV("spendable", out.fSpendable); @@ -3005,13 +2915,12 @@ static UniValue listunspent(const JSONRPCRequest& request) return results; } -void FundTransaction(CWallet* const pwallet, CMutableTransaction& tx, CAmount& fee_out, int& change_position, UniValue options) +void FundTransaction(CWallet* const pwallet, CMutableTransaction& tx, CAmount& fee_out, int& change_position, UniValue options, CCoinControl& coinControl) { // Make sure the results are valid at least up to the most recent block // the user could have gotten from another RPC command prior to now pwallet->BlockUntilSyncedToCurrentChain(); - CCoinControl coinControl; change_position = -1; bool lockUnspents = false; UniValue subtractFeeFromOutputs; @@ -3026,6 +2935,7 @@ void FundTransaction(CWallet* const pwallet, CMutableTransaction& tx, CAmount& f RPCTypeCheckArgument(options, UniValue::VOBJ); RPCTypeCheckObj(options, { + {"add_inputs", UniValueType(UniValue::VBOOL)}, {"changeAddress", UniValueType(UniValue::VSTR)}, {"changePosition", UniValueType(UniValue::VNUM)}, {"change_type", UniValueType(UniValue::VSTR)}, @@ -3039,6 +2949,10 @@ void FundTransaction(CWallet* const pwallet, CMutableTransaction& tx, CAmount& f }, true, true); + if (options.exists("add_inputs") ) { + coinControl.m_add_inputs = options["add_inputs"].get_bool(); + } + if (options.exists("changeAddress")) { CTxDestination dest = DecodeDestination(options["changeAddress"].get_str()); @@ -3056,10 +2970,11 @@ void FundTransaction(CWallet* const pwallet, CMutableTransaction& tx, CAmount& f if (options.exists("changeAddress")) { throw JSONRPCError(RPC_INVALID_PARAMETER, "Cannot specify both changeAddress and address_type options"); } - coinControl.m_change_type = pwallet->m_default_change_type; - if (!ParseOutputType(options["change_type"].get_str(), *coinControl.m_change_type)) { + OutputType out_type; + if (!ParseOutputType(options["change_type"].get_str(), out_type)) { throw JSONRPCError(RPC_INVALID_ADDRESS_OR_KEY, strprintf("Unknown change type '%s'", options["change_type"].get_str())); } + coinControl.m_change_type.emplace(out_type); } coinControl.fAllowWatchOnly = ParseIncludeWatchonly(options["includeWatching"], *pwallet); @@ -3069,6 +2984,12 @@ void FundTransaction(CWallet* const pwallet, CMutableTransaction& tx, CAmount& f if (options.exists("feeRate")) { + if (options.exists("conf_target")) { + throw JSONRPCError(RPC_INVALID_PARAMETER, "Cannot specify both conf_target and feeRate"); + } + if (options.exists("estimate_mode")) { + throw JSONRPCError(RPC_INVALID_PARAMETER, "Cannot specify both estimate_mode and feeRate"); + } coinControl.m_feerate = CFeeRate(AmountFromValue(options["feeRate"])); coinControl.fOverrideFeeRate = true; } @@ -3079,20 +3000,7 @@ void FundTransaction(CWallet* const pwallet, CMutableTransaction& tx, CAmount& f if (options.exists("replaceable")) { coinControl.m_signal_bip125_rbf = options["replaceable"].get_bool(); } - if (options.exists("conf_target")) { - if (options.exists("feeRate")) { - throw JSONRPCError(RPC_INVALID_PARAMETER, "Cannot specify both conf_target and feeRate"); - } - coinControl.m_confirm_target = ParseConfirmTarget(options["conf_target"], pwallet->chain().estimateMaxBlocks()); - } - if (options.exists("estimate_mode")) { - if (options.exists("feeRate")) { - throw JSONRPCError(RPC_INVALID_PARAMETER, "Cannot specify both estimate_mode and feeRate"); - } - if (!FeeModeFromString(options["estimate_mode"].get_str(), coinControl.m_fee_mode)) { - throw JSONRPCError(RPC_INVALID_PARAMETER, "Invalid estimate_mode parameter"); - } - } + SetFeeEstimateMode(pwallet, coinControl, options["estimate_mode"], options["conf_target"]); } } else { // if options is null and not a bool @@ -3125,16 +3033,9 @@ void FundTransaction(CWallet* const pwallet, CMutableTransaction& tx, CAmount& f static UniValue fundrawtransaction(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"fundrawtransaction", - "\nAdd inputs to a transaction until it has enough in value to meet its out value.\n" - "This will not modify existing inputs, and will add at most one change output to the outputs.\n" + "\nIf the transaction has no inputs, they will be automatically selected to meet its out value.\n" + "It will add at most one change output to the outputs.\n" "No existing outputs will be modified unless \"subtractFeeFromOutputs\" is specified.\n" "Note that inputs which were signed may need to be resigned after completion since in/outputs have been added.\n" "The inputs added will not be signed, use signrawtransactionwithkey\n" @@ -3148,6 +3049,7 @@ static UniValue fundrawtransaction(const JSONRPCRequest& request) {"hexstring", RPCArg::Type::STR_HEX, RPCArg::Optional::NO, "The hex string of the raw transaction"}, {"options", RPCArg::Type::OBJ, RPCArg::Optional::OMITTED_NAMED_ARG, "for backward compatibility: passing in a true instead of an object will result in {\"includeWatching\":true}", { + {"add_inputs", RPCArg::Type::BOOL, /* default */ "true", "For a transaction with existing inputs, automatically include more if they are not enough."}, {"changeAddress", RPCArg::Type::STR, /* default */ "pool address", "The bitcoin address to receive the change"}, {"changePosition", RPCArg::Type::NUM, /* default */ "random", "The index of the change output"}, {"change_type", RPCArg::Type::STR, /* default */ "set by -changetype", "The output type to use. Only valid if changeAddress is not specified. Options are \"legacy\", \"p2sh-segwit\", and \"bech32\"."}, @@ -3166,11 +3068,9 @@ static UniValue fundrawtransaction(const JSONRPCRequest& request) }, {"replaceable", RPCArg::Type::BOOL, /* default */ "wallet default", "Marks this transaction as BIP125 replaceable.\n" " Allows this transaction to be replaced by a transaction with higher fees"}, - {"conf_target", RPCArg::Type::NUM, /* default */ "wallet default", "Confirmation target (in blocks)"}, - {"estimate_mode", RPCArg::Type::STR, /* default */ "UNSET", "The fee estimate mode, must be one of:\n" - " \"UNSET\"\n" - " \"ECONOMICAL\"\n" - " \"CONSERVATIVE\""}, + {"conf_target", RPCArg::Type::NUM, /* default */ "wallet default", "Confirmation target (in blocks), or fee rate (for " + CURRENCY_UNIT + "/kB or " + CURRENCY_ATOM + "/B estimate modes)"}, + {"estimate_mode", RPCArg::Type::STR, /* default */ "unset", std::string() + "The fee estimate mode, must be one of (case insensitive):\n" + " \"" + FeeModes("\"\n\"") + "\""}, }, "options"}, {"iswitness", RPCArg::Type::BOOL, /* default */ "depends on heuristic tests", "Whether the transaction hex is a serialized witness transaction.\n" @@ -3201,6 +3101,10 @@ static UniValue fundrawtransaction(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + RPCTypeCheck(request.params, {UniValue::VSTR, UniValueType(), UniValue::VBOOL}); // parse hex string from parameter @@ -3213,7 +3117,10 @@ static UniValue fundrawtransaction(const JSONRPCRequest& request) CAmount fee; int change_position; - FundTransaction(pwallet, tx, fee, change_position, request.params[1]); + CCoinControl coin_control; + // Automatically select (additional) coins. Can be overridden by options.add_inputs. + coin_control.m_add_inputs = true; + FundTransaction(pwallet, tx, fee, change_position, request.params[1], coin_control); UniValue result(UniValue::VOBJ); result.pushKV("hex", EncodeHexTx(CTransaction(tx))); @@ -3225,13 +3132,6 @@ static UniValue fundrawtransaction(const JSONRPCRequest& request) UniValue signrawtransactionwithwallet(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - const CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"signrawtransactionwithwallet", "\nSign inputs for raw transaction (serialized, hex-encoded).\n" "The second optional argument (may be null) is an array of previous transaction outputs that\n" @@ -3285,6 +3185,10 @@ UniValue signrawtransactionwithwallet(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + const CWallet* const pwallet = wallet.get(); + RPCTypeCheck(request.params, {UniValue::VSTR, UniValue::VARR, UniValue::VSTR}, true); CMutableTransaction mtx; @@ -3319,63 +3223,72 @@ UniValue signrawtransactionwithwallet(const JSONRPCRequest& request) static UniValue bumpfee(const JSONRPCRequest& request) { + bool want_psbt = request.strMethod == "psbtbumpfee"; + + RPCHelpMan{request.strMethod, + "\nBumps the fee of an opt-in-RBF transaction T, replacing it with a new transaction B.\n" + + std::string(want_psbt ? "Returns a PSBT instead of creating and signing a new transaction.\n" : "") + + "An opt-in RBF transaction with the given txid must be in the wallet.\n" + "The command will pay the additional fee by reducing change outputs or adding inputs when necessary. It may add a new change output if one does not already exist.\n" + "All inputs in the original transaction will be included in the replacement transaction.\n" + "The command will fail if the wallet or mempool contains a transaction that spends one of T's outputs.\n" + "By default, the new fee will be calculated automatically using estimatesmartfee.\n" + "The user can specify a confirmation target for estimatesmartfee.\n" + "Alternatively, the user can specify a fee_rate (" + CURRENCY_UNIT + " per kB) for the new transaction.\n" + "At a minimum, the new fee rate must be high enough to pay an additional new relay fee (incrementalfee\n" + "returned by getnetworkinfo) to enter the node's mempool.\n", + { + {"txid", RPCArg::Type::STR_HEX, RPCArg::Optional::NO, "The txid to be bumped"}, + {"options", RPCArg::Type::OBJ, RPCArg::Optional::OMITTED_NAMED_ARG, "", + { + {"conf_target", RPCArg::Type::NUM, /* default */ "wallet default", "Confirmation target (in blocks)"}, + {"fee_rate", RPCArg::Type::NUM, /* default */ "fall back to 'conf_target'", "fee rate (NOT total fee) to pay, in " + CURRENCY_UNIT + " per kB\n" + " Specify a fee rate instead of relying on the built-in fee estimator.\n" + "Must be at least 0.0001 " + CURRENCY_UNIT + " per kB higher than the current transaction fee rate.\n"}, + {"replaceable", RPCArg::Type::BOOL, /* default */ "true", "Whether the new transaction should still be\n" + " marked bip-125 replaceable. If true, the sequence numbers in the transaction will\n" + " be left unchanged from the original. If false, any input sequence numbers in the\n" + " original transaction that were less than 0xfffffffe will be increased to 0xfffffffe\n" + " so the new transaction will not be explicitly bip-125 replaceable (though it may\n" + " still be replaceable in practice, for example if it has unconfirmed ancestors which\n" + " are replaceable)."}, + {"estimate_mode", RPCArg::Type::STR, /* default */ "unset", std::string() + "The fee estimate mode, must be one of (case insensitive):\n" + " \"" + FeeModes("\"\n\"") + "\""}, + }, + "options"}, + }, + RPCResult{ + RPCResult::Type::OBJ, "", "", Cat(Cat<std::vector<RPCResult>>( + { + {RPCResult::Type::STR, "psbt", "The base64-encoded unsigned PSBT of the new transaction." + std::string(want_psbt ? "" : " Only returned when wallet private keys are disabled. (DEPRECATED)")}, + }, + want_psbt ? std::vector<RPCResult>{} : std::vector<RPCResult>{{RPCResult::Type::STR_HEX, "txid", "The id of the new transaction. Only returned when wallet private keys are enabled."}} + ), + { + {RPCResult::Type::STR_AMOUNT, "origfee", "The fee of the replaced transaction."}, + {RPCResult::Type::STR_AMOUNT, "fee", "The fee of the new transaction."}, + {RPCResult::Type::ARR, "errors", "Errors encountered during processing (may be empty).", + { + {RPCResult::Type::STR, "", ""}, + }}, + }) + }, + RPCExamples{ + "\nBump the fee, get the new transaction\'s" + std::string(want_psbt ? "psbt" : "txid") + "\n" + + HelpExampleCli(request.strMethod, "<txid>") + }, + }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) - return NullUniValue; - - RPCHelpMan{"bumpfee", - "\nBumps the fee of an opt-in-RBF transaction T, replacing it with a new transaction B.\n" - "An opt-in RBF transaction with the given txid must be in the wallet.\n" - "The command will pay the additional fee by reducing change outputs or adding inputs when necessary. It may add a new change output if one does not already exist.\n" - "All inputs in the original transaction will be included in the replacement transaction.\n" - "The command will fail if the wallet or mempool contains a transaction that spends one of T's outputs.\n" - "By default, the new fee will be calculated automatically using estimatesmartfee.\n" - "The user can specify a confirmation target for estimatesmartfee.\n" - "Alternatively, the user can specify a fee_rate (" + CURRENCY_UNIT + " per kB) for the new transaction.\n" - "At a minimum, the new fee rate must be high enough to pay an additional new relay fee (incrementalfee\n" - "returned by getnetworkinfo) to enter the node's mempool.\n", - { - {"txid", RPCArg::Type::STR_HEX, RPCArg::Optional::NO, "The txid to be bumped"}, - {"options", RPCArg::Type::OBJ, RPCArg::Optional::OMITTED_NAMED_ARG, "", - { - {"confTarget", RPCArg::Type::NUM, /* default */ "wallet default", "Confirmation target (in blocks)"}, - {"fee_rate", RPCArg::Type::NUM, /* default */ "fall back to 'confTarget'", "fee rate (NOT total fee) to pay, in " + CURRENCY_UNIT + " per kB\n" - " Specify a fee rate instead of relying on the built-in fee estimator.\n" - "Must be at least 0.0001 " + CURRENCY_UNIT + " per kB higher than the current transaction fee rate.\n"}, - {"replaceable", RPCArg::Type::BOOL, /* default */ "true", "Whether the new transaction should still be\n" - " marked bip-125 replaceable. If true, the sequence numbers in the transaction will\n" - " be left unchanged from the original. If false, any input sequence numbers in the\n" - " original transaction that were less than 0xfffffffe will be increased to 0xfffffffe\n" - " so the new transaction will not be explicitly bip-125 replaceable (though it may\n" - " still be replaceable in practice, for example if it has unconfirmed ancestors which\n" - " are replaceable)."}, - {"estimate_mode", RPCArg::Type::STR, /* default */ "UNSET", "The fee estimate mode, must be one of:\n" - " \"UNSET\"\n" - " \"ECONOMICAL\"\n" - " \"CONSERVATIVE\""}, - }, - "options"}, - }, - RPCResult{ - RPCResult::Type::OBJ, "", "", { - {RPCResult::Type::STR, "psbt", "The base64-encoded unsigned PSBT of the new transaction. Only returned when wallet private keys are disabled."}, - {RPCResult::Type::STR_HEX, "txid", "The id of the new transaction. Only returned when wallet private keys are enabled."}, - {RPCResult::Type::STR_AMOUNT, "origfee", "The fee of the replaced transaction."}, - {RPCResult::Type::STR_AMOUNT, "fee", "The fee of the new transaction."}, - {RPCResult::Type::ARR, "errors", "Errors encountered during processing (may be empty).", - { - {RPCResult::Type::STR, "", ""}, - }}, - } - }, - RPCExamples{ - "\nBump the fee, get the new transaction\'s txid\n" + - HelpExampleCli("bumpfee", "<txid>") - }, - }.Check(request); + if (pwallet->IsWalletFlagSet(WALLET_FLAG_DISABLE_PRIVATE_KEYS) && !want_psbt) { + if (!pwallet->chain().rpcEnableDeprecated("bumpfee")) { + throw JSONRPCError(RPC_METHOD_DEPRECATED, "Using bumpfee with wallets that have private keys disabled is deprecated. Use psbtbumpfee instead or restart bitcoind with -deprecatedrpc=bumpfee. This functionality will be removed in 0.22"); + } + want_psbt = true; + } RPCTypeCheck(request.params, {UniValue::VSTR, UniValue::VOBJ}); uint256 hash(ParseHashV(request.params[0], "txid")); @@ -3390,15 +3303,24 @@ static UniValue bumpfee(const JSONRPCRequest& request) RPCTypeCheckObj(options, { {"confTarget", UniValueType(UniValue::VNUM)}, + {"conf_target", UniValueType(UniValue::VNUM)}, {"fee_rate", UniValueType(UniValue::VNUM)}, {"replaceable", UniValueType(UniValue::VBOOL)}, {"estimate_mode", UniValueType(UniValue::VSTR)}, }, true, true); - if (options.exists("confTarget") && options.exists("fee_rate")) { - throw JSONRPCError(RPC_INVALID_PARAMETER, "confTarget can't be set with fee_rate. Please provide either a confirmation target in blocks for automatic fee estimation, or an explicit fee rate."); - } else if (options.exists("confTarget")) { // TODO: alias this to conf_target - coin_control.m_confirm_target = ParseConfirmTarget(options["confTarget"], pwallet->chain().estimateMaxBlocks()); + + if (options.exists("confTarget") && options.exists("conf_target")) { + throw JSONRPCError(RPC_INVALID_PARAMETER, "confTarget and conf_target options should not both be set. Use conf_target (confTarget is deprecated)."); + } + + auto conf_target = options.exists("confTarget") ? options["confTarget"] : options["conf_target"]; + + if (!conf_target.isNull()) { + if (options.exists("fee_rate")) { + throw JSONRPCError(RPC_INVALID_PARAMETER, "conf_target can't be set with fee_rate. Please provide either a confirmation target in blocks for automatic fee estimation, or an explicit fee rate."); + } + coin_control.m_confirm_target = ParseConfirmTarget(conf_target, pwallet->chain().estimateMaxBlocks()); } else if (options.exists("fee_rate")) { CFeeRate fee_rate(AmountFromValue(options["fee_rate"])); if (fee_rate <= CFeeRate(0)) { @@ -3410,11 +3332,7 @@ static UniValue bumpfee(const JSONRPCRequest& request) if (options.exists("replaceable")) { coin_control.m_signal_bip125_rbf = options["replaceable"].get_bool(); } - if (options.exists("estimate_mode")) { - if (!FeeModeFromString(options["estimate_mode"].get_str(), coin_control.m_fee_mode)) { - throw JSONRPCError(RPC_INVALID_PARAMETER, "Invalid estimate_mode parameter"); - } - } + SetFeeEstimateMode(pwallet, coin_control, options["estimate_mode"], conf_target); } // Make sure the results are valid at least up to the most recent block @@ -3456,7 +3374,7 @@ static UniValue bumpfee(const JSONRPCRequest& request) // If wallet private keys are enabled, return the new transaction id, // otherwise return the base64-encoded unsigned PSBT of the new transaction. - if (!pwallet->IsWalletFlagSet(WALLET_FLAG_DISABLE_PRIVATE_KEYS)) { + if (!want_psbt) { if (!feebumper::SignTransaction(*pwallet, mtx)) { throw JSONRPCError(RPC_WALLET_ERROR, "Can't sign transaction."); } @@ -3489,15 +3407,13 @@ static UniValue bumpfee(const JSONRPCRequest& request) return result; } -UniValue rescanblockchain(const JSONRPCRequest& request) +static UniValue psbtbumpfee(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } + return bumpfee(request); +} +UniValue rescanblockchain(const JSONRPCRequest& request) +{ RPCHelpMan{"rescanblockchain", "\nRescan the local blockchain for wallet related transactions.\n" "Note: Use \"getwalletinfo\" to query the scanning progress.\n", @@ -3518,6 +3434,10 @@ UniValue rescanblockchain(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + WalletRescanReserver reserver(*pwallet); if (!reserver.reserve()) { throw JSONRPCError(RPC_WALLET_ERROR, "Wallet is currently rescanning. Abort existing rescan or wait."); @@ -3580,9 +3500,9 @@ public: { // Always present: script type and redeemscript std::vector<std::vector<unsigned char>> solutions_data; - txnouttype which_type = Solver(subscript, solutions_data); + TxoutType which_type = Solver(subscript, solutions_data); obj.pushKV("script", GetTxnOutputType(which_type)); - obj.pushKV("hex", HexStr(subscript.begin(), subscript.end())); + obj.pushKV("hex", HexStr(subscript)); CTxDestination embedded; if (ExtractDestination(subscript, embedded)) { @@ -3593,18 +3513,18 @@ public: UniValue wallet_detail = boost::apply_visitor(*this, embedded); subobj.pushKVs(wallet_detail); subobj.pushKV("address", EncodeDestination(embedded)); - subobj.pushKV("scriptPubKey", HexStr(subscript.begin(), subscript.end())); + subobj.pushKV("scriptPubKey", HexStr(subscript)); // Always report the pubkey at the top level, so that `getnewaddress()['pubkey']` always works. if (subobj.exists("pubkey")) obj.pushKV("pubkey", subobj["pubkey"]); obj.pushKV("embedded", std::move(subobj)); - } else if (which_type == TX_MULTISIG) { + } else if (which_type == TxoutType::MULTISIG) { // Also report some information on multisig scripts (which do not have a corresponding address). // TODO: abstract out the common functionality between this logic and ExtractDestinations. obj.pushKV("sigsrequired", solutions_data[0][0]); UniValue pubkeys(UniValue::VARR); for (size_t i = 1; i < solutions_data.size() - 1; ++i) { CPubKey key(solutions_data[i].begin(), solutions_data[i].end()); - pubkeys.push_back(HexStr(key.begin(), key.end())); + pubkeys.push_back(HexStr(key)); } obj.pushKV("pubkeys", std::move(pubkeys)); } @@ -3616,7 +3536,7 @@ public: UniValue operator()(const PKHash& pkhash) const { - CKeyID keyID(pkhash); + CKeyID keyID{ToKeyID(pkhash)}; UniValue obj(UniValue::VOBJ); CPubKey vchPubKey; if (provider && provider->GetPubKey(keyID, vchPubKey)) { @@ -3641,7 +3561,7 @@ public: { UniValue obj(UniValue::VOBJ); CPubKey pubkey; - if (provider && provider->GetPubKey(CKeyID(id), pubkey)) { + if (provider && provider->GetPubKey(ToKeyID(id), pubkey)) { obj.pushKV("pubkey", HexStr(pubkey)); } return obj; @@ -3690,13 +3610,6 @@ static UniValue AddressBookDataToJSON(const CAddressBookData& data, const bool v UniValue getaddressinfo(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - const CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"getaddressinfo", "\nReturn information about the given bitcoin address.\n" "Some of the information will only be present if the address is in the active wallet.\n", @@ -3729,12 +3642,10 @@ UniValue getaddressinfo(const JSONRPCRequest& request) {RPCResult::Type::STR_HEX, "pubkey", /* optional */ true, "The hex value of the raw public key for single-key addresses (possibly embedded in P2SH or P2WSH)."}, {RPCResult::Type::OBJ, "embedded", /* optional */ true, "Information about the address embedded in P2SH or P2WSH, if relevant and known.", { - {RPCResult::Type::ELISION, "", "Includes all\n" - " getaddressinfo output fields for the embedded address, excluding metadata (timestamp, hdkeypath,\n" - "hdseedid) and relation to the wallet (ismine, iswatchonly)."}, + {RPCResult::Type::ELISION, "", "Includes all getaddressinfo output fields for the embedded address, excluding metadata (timestamp, hdkeypath, hdseedid)\n" + "and relation to the wallet (ismine, iswatchonly)."}, }}, {RPCResult::Type::BOOL, "iscompressed", /* optional */ true, "If the pubkey is compressed."}, - {RPCResult::Type::STR, "label", "DEPRECATED. The label associated with the address. Defaults to \"\". Replaced by the labels array below."}, {RPCResult::Type::NUM_TIME, "timestamp", /* optional */ true, "The creation time of the key, if available, expressed in " + UNIX_EPOCH_TIME + "."}, {RPCResult::Type::STR, "hdkeypath", /* optional */ true, "The HD keypath, if the key is HD and available."}, {RPCResult::Type::STR_HEX, "hdseedid", /* optional */ true, "The Hash160 of the HD seed."}, @@ -3742,12 +3653,7 @@ UniValue getaddressinfo(const JSONRPCRequest& request) {RPCResult::Type::ARR, "labels", "Array of labels associated with the address. Currently limited to one label but returned\n" "as an array to keep the API stable if multiple labels are enabled in the future.", { - {RPCResult::Type::STR, "label name", "The label name. Defaults to \"\"."}, - {RPCResult::Type::OBJ, "", "label data, DEPRECATED, will be removed in 0.21. To re-enable, launch bitcoind with `-deprecatedrpc=labelspurpose`", - { - {RPCResult::Type::STR, "name", "The label name. Defaults to \"\"."}, - {RPCResult::Type::STR, "purpose", "The purpose of the associated address (send or receive)."}, - }}, + {RPCResult::Type::STR, "label name", "Label name (defaults to \"\")."}, }}, } }, @@ -3757,6 +3663,10 @@ UniValue getaddressinfo(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + const CWallet* const pwallet = wallet.get(); + LOCK(pwallet->cs_wallet); UniValue ret(UniValue::VOBJ); @@ -3770,7 +3680,7 @@ UniValue getaddressinfo(const JSONRPCRequest& request) ret.pushKV("address", currentAddress); CScript scriptPubKey = GetScriptForDestination(dest); - ret.pushKV("scriptPubKey", HexStr(scriptPubKey.begin(), scriptPubKey.end())); + ret.pushKV("scriptPubKey", HexStr(scriptPubKey)); std::unique_ptr<SigningProvider> provider = pwallet->GetSolvingProvider(scriptPubKey); @@ -3789,14 +3699,6 @@ UniValue getaddressinfo(const JSONRPCRequest& request) UniValue detail = DescribeWalletAddress(pwallet, dest); ret.pushKVs(detail); - // DEPRECATED: Return label field if existing. Currently only one label can - // be associated with an address, so the label should be equivalent to the - // value of the name key/value pair in the labels array below. - const auto* address_book_entry = pwallet->FindAddressBookEntry(dest); - if (pwallet->chain().rpcEnableDeprecated("label") && address_book_entry) { - ret.pushKV("label", address_book_entry->GetLabel()); - } - ret.pushKV("ischange", pwallet->IsChange(scriptPubKey)); ScriptPubKeyMan* spk_man = pwallet->GetScriptPubKeyMan(scriptPubKey); @@ -3806,7 +3708,7 @@ UniValue getaddressinfo(const JSONRPCRequest& request) if (meta->has_key_origin) { ret.pushKV("hdkeypath", WriteHDKeypath(meta->key_origin.path)); ret.pushKV("hdseedid", meta->hd_seed_id.GetHex()); - ret.pushKV("hdmasterfingerprint", HexStr(meta->key_origin.fingerprint, meta->key_origin.fingerprint + 4)); + ret.pushKV("hdmasterfingerprint", HexStr(meta->key_origin.fingerprint)); } } } @@ -3817,14 +3719,9 @@ UniValue getaddressinfo(const JSONRPCRequest& request) // stable if we allow multiple labels to be associated with an address in // the future. UniValue labels(UniValue::VARR); + const auto* address_book_entry = pwallet->FindAddressBookEntry(dest); if (address_book_entry) { - // DEPRECATED: The previous behavior of returning an array containing a - // JSON object of `name` and `purpose` key/value pairs is deprecated. - if (pwallet->chain().rpcEnableDeprecated("labelspurpose")) { - labels.push_back(AddressBookDataToJSON(*address_book_entry, true)); - } else { - labels.push_back(address_book_entry->GetLabel()); - } + labels.push_back(address_book_entry->GetLabel()); } ret.pushKV("labels", std::move(labels)); @@ -3833,13 +3730,6 @@ UniValue getaddressinfo(const JSONRPCRequest& request) static UniValue getaddressesbylabel(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - const CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"getaddressesbylabel", "\nReturns the list of addresses assigned the specified label.\n", { @@ -3860,6 +3750,10 @@ static UniValue getaddressesbylabel(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + const CWallet* const pwallet = wallet.get(); + LOCK(pwallet->cs_wallet); std::string label = LabelFromValue(request.params[0]); @@ -3893,13 +3787,6 @@ static UniValue getaddressesbylabel(const JSONRPCRequest& request) static UniValue listlabels(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - const CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"listlabels", "\nReturns the list of all labels, or labels that are assigned to addresses with a specific purpose.\n", { @@ -3923,6 +3810,10 @@ static UniValue listlabels(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + const CWallet* const pwallet = wallet.get(); + LOCK(pwallet->cs_wallet); std::string purpose; @@ -3949,13 +3840,6 @@ static UniValue listlabels(const JSONRPCRequest& request) UniValue sethdseed(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"sethdseed", "\nSet or generate a new HD wallet seed. Non-HD wallets will not be upgraded to being a HD wallet. Wallets that are already\n" "HD will have a new HD seed set so that new keys added to the keypool will be derived from this new seed.\n" @@ -3978,6 +3862,10 @@ UniValue sethdseed(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + LegacyScriptPubKeyMan& spk_man = EnsureLegacyScriptPubKeyMan(*pwallet, true); if (pwallet->IsWalletFlagSet(WALLET_FLAG_DISABLE_PRIVATE_KEYS)) { @@ -4022,13 +3910,6 @@ UniValue sethdseed(const JSONRPCRequest& request) UniValue walletprocesspsbt(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - const CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"walletprocesspsbt", "\nUpdate a PSBT with input information from our wallet and then sign inputs\n" "that we can sign for." + @@ -4057,6 +3938,10 @@ UniValue walletprocesspsbt(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + const CWallet* const pwallet = wallet.get(); + RPCTypeCheck(request.params, {UniValue::VSTR, UniValue::VBOOL, UniValue::VSTR}); // Unserialize the transaction @@ -4089,24 +3974,17 @@ UniValue walletprocesspsbt(const JSONRPCRequest& request) UniValue walletcreatefundedpsbt(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"walletcreatefundedpsbt", - "\nCreates and funds a transaction in the Partially Signed Transaction format. Inputs will be added if supplied inputs are not enough\n" + "\nCreates and funds a transaction in the Partially Signed Transaction format.\n" "Implements the Creator and Updater roles.\n", { - {"inputs", RPCArg::Type::ARR, RPCArg::Optional::NO, "The inputs", + {"inputs", RPCArg::Type::ARR, RPCArg::Optional::NO, "The inputs. Leave empty to add inputs automatically. See add_inputs option.", { {"", RPCArg::Type::OBJ, RPCArg::Optional::OMITTED, "", { {"txid", RPCArg::Type::STR_HEX, RPCArg::Optional::NO, "The transaction id"}, {"vout", RPCArg::Type::NUM, RPCArg::Optional::NO, "The output number"}, - {"sequence", RPCArg::Type::NUM, RPCArg::Optional::NO, "The sequence number"}, + {"sequence", RPCArg::Type::NUM, /* default */ "depends on the value of the 'locktime' and 'options.replaceable' arguments", "The sequence number"}, }, }, }, @@ -4131,6 +4009,7 @@ UniValue walletcreatefundedpsbt(const JSONRPCRequest& request) {"locktime", RPCArg::Type::NUM, /* default */ "0", "Raw locktime. Non-0 value also locktime-activates inputs"}, {"options", RPCArg::Type::OBJ, RPCArg::Optional::OMITTED_NAMED_ARG, "", { + {"add_inputs", RPCArg::Type::BOOL, /* default */ "false", "If inputs are specified, automatically include more if they are not enough."}, {"changeAddress", RPCArg::Type::STR_HEX, /* default */ "pool address", "The bitcoin address to receive the change"}, {"changePosition", RPCArg::Type::NUM, /* default */ "random", "The index of the change output"}, {"change_type", RPCArg::Type::STR, /* default */ "set by -changetype", "The output type to use. Only valid if changeAddress is not specified. Options are \"legacy\", \"p2sh-segwit\", and \"bech32\"."}, @@ -4148,10 +4027,8 @@ UniValue walletcreatefundedpsbt(const JSONRPCRequest& request) {"replaceable", RPCArg::Type::BOOL, /* default */ "wallet default", "Marks this transaction as BIP125 replaceable.\n" " Allows this transaction to be replaced by a transaction with higher fees"}, {"conf_target", RPCArg::Type::NUM, /* default */ "fall back to wallet's confirmation target (txconfirmtarget)", "Confirmation target (in blocks)"}, - {"estimate_mode", RPCArg::Type::STR, /* default */ "UNSET", "The fee estimate mode, must be one of:\n" - " \"UNSET\"\n" - " \"ECONOMICAL\"\n" - " \"CONSERVATIVE\""}, + {"estimate_mode", RPCArg::Type::STR, /* default */ "unset", std::string() + "The fee estimate mode, must be one of (case insensitive):\n" + " \"" + FeeModes("\"\n\"") + "\""}, }, "options"}, {"bip32derivs", RPCArg::Type::BOOL, /* default */ "true", "Include BIP 32 derivation paths for public keys if we know them"}, @@ -4170,6 +4047,10 @@ UniValue walletcreatefundedpsbt(const JSONRPCRequest& request) }, }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + RPCTypeCheck(request.params, { UniValue::VARR, UniValueType(), // ARR or OBJ, checked later @@ -4188,7 +4069,11 @@ UniValue walletcreatefundedpsbt(const JSONRPCRequest& request) rbf = replaceable_arg.isTrue(); } CMutableTransaction rawTx = ConstructTransaction(request.params[0], request.params[1], request.params[2], rbf); - FundTransaction(pwallet, rawTx, fee, change_position, request.params[3]); + CCoinControl coin_control; + // Automatically select coins, unless at least one is manually selected. Can + // be overridden by options.add_inputs. + coin_control.m_add_inputs = rawTx.vin.size() == 0; + FundTransaction(pwallet, rawTx, fee, change_position, request.params[3], coin_control); // Make a blank psbt PartiallySignedTransaction psbtx(rawTx); @@ -4214,13 +4099,6 @@ UniValue walletcreatefundedpsbt(const JSONRPCRequest& request) static UniValue upgradewallet(const JSONRPCRequest& request) { - std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); - CWallet* const pwallet = wallet.get(); - - if (!EnsureWalletIsAvailable(pwallet, request.fHelp)) { - return NullUniValue; - } - RPCHelpMan{"upgradewallet", "\nUpgrade the wallet. Upgrades to the latest version if no version number is specified\n" "New keys may be generated and a new wallet backup will need to be made.", @@ -4234,6 +4112,10 @@ static UniValue upgradewallet(const JSONRPCRequest& request) } }.Check(request); + std::shared_ptr<CWallet> const wallet = GetWalletForJSONRPCRequest(request); + if (!wallet) return NullUniValue; + CWallet* const pwallet = wallet.get(); + RPCTypeCheck(request.params, {UniValue::VNUM}, true); EnsureWalletIsUnlocked(pwallet); @@ -4263,7 +4145,7 @@ UniValue removeprunedfunds(const JSONRPCRequest& request); UniValue importmulti(const JSONRPCRequest& request); UniValue importdescriptors(const JSONRPCRequest& request); -void RegisterWalletRPCCommands(interfaces::Chain& chain, std::vector<std::unique_ptr<interfaces::Handler>>& handlers) +Span<const CRPCCommand> GetWalletRPCCommands() { // clang-format off static const CRPCCommand commands[] = @@ -4275,6 +4157,7 @@ static const CRPCCommand commands[] = { "wallet", "addmultisigaddress", &addmultisigaddress, {"nrequired","keys","label","address_type"} }, { "wallet", "backupwallet", &backupwallet, {"destination"} }, { "wallet", "bumpfee", &bumpfee, {"txid", "options"} }, + { "wallet", "psbtbumpfee", &psbtbumpfee, {"txid", "options"} }, { "wallet", "createwallet", &createwallet, {"wallet_name", "disable_private_keys", "blank", "passphrase", "avoid_reuse", "descriptors"} }, { "wallet", "dumpprivkey", &dumpprivkey, {"address"} }, { "wallet", "dumpwallet", &dumpwallet, {"filename"} }, @@ -4329,9 +4212,5 @@ static const CRPCCommand commands[] = { "wallet", "walletprocesspsbt", &walletprocesspsbt, {"psbt","sign","sighashtype","bip32derivs"} }, }; // clang-format on - - for (unsigned int vcidx = 0; vcidx < ARRAYLEN(commands); vcidx++) - handlers.emplace_back(chain.handleRpc(commands[vcidx])); + return MakeSpan(commands); } - -interfaces::Chain* g_rpc_chain = nullptr; diff --git a/src/wallet/rpcwallet.h b/src/wallet/rpcwallet.h index 8c149d455b..fb1e91282b 100644 --- a/src/wallet/rpcwallet.h +++ b/src/wallet/rpcwallet.h @@ -5,30 +5,22 @@ #ifndef BITCOIN_WALLET_RPCWALLET_H #define BITCOIN_WALLET_RPCWALLET_H +#include <span.h> + #include <memory> #include <string> #include <vector> -class CRPCTable; +class CRPCCommand; class CWallet; class JSONRPCRequest; class LegacyScriptPubKeyMan; class UniValue; -struct PartiallySignedTransaction; class CTransaction; +struct PartiallySignedTransaction; +struct WalletContext; -namespace interfaces { -class Chain; -class Handler; -} - -//! Pointer to chain interface that needs to be declared as a global to be -//! accessible loadwallet and createwallet methods. Due to limitations of the -//! RPC framework, there's currently no direct way to pass in state to RPC -//! methods without globals. -extern interfaces::Chain* g_rpc_chain; - -void RegisterWalletRPCCommands(interfaces::Chain& chain, std::vector<std::unique_ptr<interfaces::Handler>>& handlers); +Span<const CRPCCommand> GetWalletRPCCommands(); /** * Figures out what wallet, if any, to use for a JSONRPCRequest. @@ -39,7 +31,7 @@ void RegisterWalletRPCCommands(interfaces::Chain& chain, std::vector<std::unique std::shared_ptr<CWallet> GetWalletForJSONRPCRequest(const JSONRPCRequest& request); void EnsureWalletIsUnlocked(const CWallet*); -bool EnsureWalletIsAvailable(const CWallet*, bool avoidException); +WalletContext& EnsureWalletContext(const util::Ref& context); LegacyScriptPubKeyMan& EnsureLegacyScriptPubKeyMan(CWallet& wallet, bool also_create = false); UniValue getaddressinfo(const JSONRPCRequest& request); diff --git a/src/wallet/salvage.cpp b/src/wallet/salvage.cpp index 70067ebef0..af57210f01 100644 --- a/src/wallet/salvage.cpp +++ b/src/wallet/salvage.cpp @@ -5,6 +5,7 @@ #include <fs.h> #include <streams.h> +#include <util/translation.h> #include <wallet/salvage.h> #include <wallet/wallet.h> #include <wallet/walletdb.h> @@ -20,6 +21,12 @@ bool RecoverDatabaseFile(const fs::path& file_path) std::string filename; std::shared_ptr<BerkeleyEnvironment> env = GetWalletEnv(file_path, filename); + bilingual_str open_err; + if (!env->Open(open_err)) { + tfm::format(std::cerr, "%s\n", open_err.original); + return false; + } + // Recovery procedure: // move wallet file to walletfilename.timestamp.bak // Call Salvage with fAggressive=true to @@ -116,7 +123,7 @@ bool RecoverDatabaseFile(const fs::path& file_path) } DbTxn* ptxn = env->TxnBegin(); - CWallet dummyWallet(nullptr, WalletLocation(), WalletDatabase::CreateDummy()); + CWallet dummyWallet(nullptr, WalletLocation(), CreateDummyWalletDatabase()); for (KeyValPair& row : salvagedData) { /* Filter for only private key type KV pairs to be added to the salvaged wallet */ diff --git a/src/wallet/scriptpubkeyman.cpp b/src/wallet/scriptpubkeyman.cpp index 8a2a798644..51715462c5 100644 --- a/src/wallet/scriptpubkeyman.cpp +++ b/src/wallet/scriptpubkeyman.cpp @@ -88,16 +88,16 @@ IsMineResult IsMineInner(const LegacyScriptPubKeyMan& keystore, const CScript& s IsMineResult ret = IsMineResult::NO; std::vector<valtype> vSolutions; - txnouttype whichType = Solver(scriptPubKey, vSolutions); + TxoutType whichType = Solver(scriptPubKey, vSolutions); CKeyID keyID; switch (whichType) { - case TX_NONSTANDARD: - case TX_NULL_DATA: - case TX_WITNESS_UNKNOWN: + case TxoutType::NONSTANDARD: + case TxoutType::NULL_DATA: + case TxoutType::WITNESS_UNKNOWN: break; - case TX_PUBKEY: + case TxoutType::PUBKEY: keyID = CPubKey(vSolutions[0]).GetID(); if (!PermitsUncompressed(sigversion) && vSolutions[0].size() != 33) { return IsMineResult::INVALID; @@ -106,7 +106,7 @@ IsMineResult IsMineInner(const LegacyScriptPubKeyMan& keystore, const CScript& s ret = std::max(ret, IsMineResult::SPENDABLE); } break; - case TX_WITNESS_V0_KEYHASH: + case TxoutType::WITNESS_V0_KEYHASH: { if (sigversion == IsMineSigVersion::WITNESS_V0) { // P2WPKH inside P2WSH is invalid. @@ -121,7 +121,7 @@ IsMineResult IsMineInner(const LegacyScriptPubKeyMan& keystore, const CScript& s ret = std::max(ret, IsMineInner(keystore, GetScriptForDestination(PKHash(uint160(vSolutions[0]))), IsMineSigVersion::WITNESS_V0)); break; } - case TX_PUBKEYHASH: + case TxoutType::PUBKEYHASH: keyID = CKeyID(uint160(vSolutions[0])); if (!PermitsUncompressed(sigversion)) { CPubKey pubkey; @@ -133,7 +133,7 @@ IsMineResult IsMineInner(const LegacyScriptPubKeyMan& keystore, const CScript& s ret = std::max(ret, IsMineResult::SPENDABLE); } break; - case TX_SCRIPTHASH: + case TxoutType::SCRIPTHASH: { if (sigversion != IsMineSigVersion::TOP) { // P2SH inside P2WSH or P2SH is invalid. @@ -146,7 +146,7 @@ IsMineResult IsMineInner(const LegacyScriptPubKeyMan& keystore, const CScript& s } break; } - case TX_WITNESS_V0_SCRIPTHASH: + case TxoutType::WITNESS_V0_SCRIPTHASH: { if (sigversion == IsMineSigVersion::WITNESS_V0) { // P2WSH inside P2WSH is invalid. @@ -165,7 +165,7 @@ IsMineResult IsMineInner(const LegacyScriptPubKeyMan& keystore, const CScript& s break; } - case TX_MULTISIG: + case TxoutType::MULTISIG: { // Never treat bare multisig outputs as ours (they can still be made watchonly-though) if (sigversion == IsMineSigVersion::TOP) { @@ -573,9 +573,8 @@ bool LegacyScriptPubKeyMan::SignTransaction(CMutableTransaction& tx, const std:: SigningResult LegacyScriptPubKeyMan::SignMessage(const std::string& message, const PKHash& pkhash, std::string& str_sig) const { - CKeyID key_id(pkhash); CKey key; - if (!GetKey(key_id, key)) { + if (!GetKey(ToKeyID(pkhash), key)) { return SigningResult::PRIVATE_KEY_NOT_AVAILABLE; } @@ -585,8 +584,11 @@ SigningResult LegacyScriptPubKeyMan::SignMessage(const std::string& message, con return SigningResult::SIGNING_FAILED; } -TransactionError LegacyScriptPubKeyMan::FillPSBT(PartiallySignedTransaction& psbtx, int sighash_type, bool sign, bool bip32derivs) const +TransactionError LegacyScriptPubKeyMan::FillPSBT(PartiallySignedTransaction& psbtx, int sighash_type, bool sign, bool bip32derivs, int* n_signed) const { + if (n_signed) { + *n_signed = 0; + } for (unsigned int i = 0; i < psbtx.tx->vin.size(); ++i) { const CTxIn& txin = psbtx.tx->vin[i]; PSBTInput& input = psbtx.inputs.at(i); @@ -595,11 +597,6 @@ TransactionError LegacyScriptPubKeyMan::FillPSBT(PartiallySignedTransaction& psb continue; } - // Verify input looks sane. This will check that we have at most one uxto, witness or non-witness. - if (!input.IsSane()) { - return TransactionError::INVALID_PSBT; - } - // Get the Sighash type if (sign && input.sighash_type > 0 && input.sighash_type != sighash_type) { return TransactionError::SIGHASH_MISMATCH; @@ -617,6 +614,14 @@ TransactionError LegacyScriptPubKeyMan::FillPSBT(PartiallySignedTransaction& psb SignatureData sigdata; input.FillSignatureData(sigdata); SignPSBTInput(HidingSigningProvider(this, !sign, !bip32derivs), psbtx, i, sighash_type); + + bool signed_one = PSBTInputSigned(input); + if (n_signed && (signed_one || !sign)) { + // If sign is false, we assume that we _could_ sign if we get here. This + // will never have false negatives; it is hard to tell under what i + // circumstances it could have false positives. + (*n_signed)++; + } } // Fill in the bip32 keypaths and redeemscripts for the outputs so that hardware wallets can identify change @@ -826,7 +831,7 @@ bool LegacyScriptPubKeyMan::HaveWatchOnly() const static bool ExtractPubKey(const CScript &dest, CPubKey& pubKeyOut) { std::vector<std::vector<unsigned char>> solutions; - return Solver(dest, solutions) == TX_PUBKEY && + return Solver(dest, solutions) == TxoutType::PUBKEY && (pubKeyOut = CPubKey(solutions[0])).IsFullyValid(); } @@ -900,20 +905,22 @@ bool LegacyScriptPubKeyMan::AddWatchOnly(const CScript& dest, int64_t nCreateTim return AddWatchOnly(dest); } -void LegacyScriptPubKeyMan::SetHDChain(const CHDChain& chain, bool memonly) +void LegacyScriptPubKeyMan::LoadHDChain(const CHDChain& chain) { LOCK(cs_KeyStore); - // memonly == true means we are loading the wallet file - // memonly == false means that the chain is actually being changed - if (!memonly) { - // Store the new chain - if (!WalletBatch(m_storage.GetDatabase()).WriteHDChain(chain)) { - throw std::runtime_error(std::string(__func__) + ": writing chain failed"); - } - // When there's an old chain, add it as an inactive chain as we are now rotating hd chains - if (!m_hd_chain.seed_id.IsNull()) { - AddInactiveHDChain(m_hd_chain); - } + m_hd_chain = chain; +} + +void LegacyScriptPubKeyMan::AddHDChain(const CHDChain& chain) +{ + LOCK(cs_KeyStore); + // Store the new chain + if (!WalletBatch(m_storage.GetDatabase()).WriteHDChain(chain)) { + throw std::runtime_error(std::string(__func__) + ": writing chain failed"); + } + // When there's an old chain, add it as an inactive chain as we are now rotating hd chains + if (!m_hd_chain.seed_id.IsNull()) { + AddInactiveHDChain(m_hd_chain); } m_hd_chain = chain; @@ -1167,7 +1174,7 @@ void LegacyScriptPubKeyMan::SetHDSeed(const CPubKey& seed) CHDChain newHdChain; newHdChain.nVersion = m_storage.CanSupportFeature(FEATURE_HD_SPLIT) ? CHDChain::VERSION_HD_CHAIN_SPLIT : CHDChain::VERSION_HD_BASE; newHdChain.seed_id = seed.GetID(); - SetHDChain(newHdChain, false); + AddHDChain(newHdChain); NotifyCanGetAddressesChanged(); WalletBatch batch(m_storage.GetDatabase()); m_storage.UnsetBlankWalletFlag(batch); @@ -1890,8 +1897,8 @@ bool DescriptorScriptPubKeyMan::SetupDescriptorGeneration(const CExtKey& master_ desc_prefix = "wpkh(" + xpub + "/84'"; break; } - default: assert(false); - } + } // no default case, so the compiler can warn about missing cases + assert(!desc_prefix.empty()); // Mainnet derives at 0', testnet and regtest derive at 1' if (Params().IsTestChain()) { @@ -2052,9 +2059,8 @@ SigningResult DescriptorScriptPubKeyMan::SignMessage(const std::string& message, return SigningResult::PRIVATE_KEY_NOT_AVAILABLE; } - CKeyID key_id(pkhash); CKey key; - if (!keys->GetKey(key_id, key)) { + if (!keys->GetKey(ToKeyID(pkhash), key)) { return SigningResult::PRIVATE_KEY_NOT_AVAILABLE; } @@ -2064,8 +2070,11 @@ SigningResult DescriptorScriptPubKeyMan::SignMessage(const std::string& message, return SigningResult::OK; } -TransactionError DescriptorScriptPubKeyMan::FillPSBT(PartiallySignedTransaction& psbtx, int sighash_type, bool sign, bool bip32derivs) const +TransactionError DescriptorScriptPubKeyMan::FillPSBT(PartiallySignedTransaction& psbtx, int sighash_type, bool sign, bool bip32derivs, int* n_signed) const { + if (n_signed) { + *n_signed = 0; + } for (unsigned int i = 0; i < psbtx.tx->vin.size(); ++i) { const CTxIn& txin = psbtx.tx->vin[i]; PSBTInput& input = psbtx.inputs.at(i); @@ -2074,11 +2083,6 @@ TransactionError DescriptorScriptPubKeyMan::FillPSBT(PartiallySignedTransaction& continue; } - // Verify input looks sane. This will check that we have at most one uxto, witness or non-witness. - if (!input.IsSane()) { - return TransactionError::INVALID_PSBT; - } - // Get the Sighash type if (sign && input.sighash_type > 0 && input.sighash_type != sighash_type) { return TransactionError::SIGHASH_MISMATCH; @@ -2117,6 +2121,14 @@ TransactionError DescriptorScriptPubKeyMan::FillPSBT(PartiallySignedTransaction& } SignPSBTInput(HidingSigningProvider(keys.get(), !sign, !bip32derivs), psbtx, i, sighash_type); + + bool signed_one = PSBTInputSigned(input); + if (n_signed && (signed_one || !sign)) { + // If sign is false, we assume that we _could_ sign if we get here. This + // will never have false negatives; it is hard to tell under what i + // circumstances it could have false positives. + (*n_signed)++; + } } // Fill in the bip32 keypaths and redeemscripts for the outputs so that hardware wallets can identify change diff --git a/src/wallet/scriptpubkeyman.h b/src/wallet/scriptpubkeyman.h index d62d30f339..a96d971734 100644 --- a/src/wallet/scriptpubkeyman.h +++ b/src/wallet/scriptpubkeyman.h @@ -234,7 +234,7 @@ public: /** Sign a message with the given script */ virtual SigningResult SignMessage(const std::string& message, const PKHash& pkhash, std::string& str_sig) const { return SigningResult::SIGNING_FAILED; }; /** Adds script and derivation path information to a PSBT, and optionally signs it. */ - virtual TransactionError FillPSBT(PartiallySignedTransaction& psbt, int sighash_type = 1 /* SIGHASH_ALL */, bool sign = true, bool bip32derivs = false) const { return TransactionError::INVALID_PSBT; } + virtual TransactionError FillPSBT(PartiallySignedTransaction& psbt, int sighash_type = 1 /* SIGHASH_ALL */, bool sign = true, bool bip32derivs = false, int* n_signed = nullptr) const { return TransactionError::INVALID_PSBT; } virtual uint256 GetID() const { return uint256(); } @@ -393,7 +393,7 @@ public: bool SignTransaction(CMutableTransaction& tx, const std::map<COutPoint, Coin>& coins, int sighash, std::map<int, std::string>& input_errors) const override; SigningResult SignMessage(const std::string& message, const PKHash& pkhash, std::string& str_sig) const override; - TransactionError FillPSBT(PartiallySignedTransaction& psbt, int sighash_type = 1 /* SIGHASH_ALL */, bool sign = true, bool bip32derivs = false) const override; + TransactionError FillPSBT(PartiallySignedTransaction& psbt, int sighash_type = 1 /* SIGHASH_ALL */, bool sign = true, bool bip32derivs = false, int* n_signed = nullptr) const override; uint256 GetID() const override; @@ -422,8 +422,10 @@ public: //! Generate a new key CPubKey GenerateNewKey(WalletBatch& batch, CHDChain& hd_chain, bool internal = false) EXCLUSIVE_LOCKS_REQUIRED(cs_KeyStore); - /* Set the HD chain model (chain child index counters) */ - void SetHDChain(const CHDChain& chain, bool memonly); + /* Set the HD chain model (chain child index counters) and writes it to the database */ + void AddHDChain(const CHDChain& chain); + //! Load a HD chain model (used by LoadWallet) + void LoadHDChain(const CHDChain& chain); const CHDChain& GetHDChain() const { return m_hd_chain; } void AddInactiveHDChain(const CHDChain& chain); @@ -596,7 +598,7 @@ public: bool SignTransaction(CMutableTransaction& tx, const std::map<COutPoint, Coin>& coins, int sighash, std::map<int, std::string>& input_errors) const override; SigningResult SignMessage(const std::string& message, const PKHash& pkhash, std::string& str_sig) const override; - TransactionError FillPSBT(PartiallySignedTransaction& psbt, int sighash_type = 1 /* SIGHASH_ALL */, bool sign = true, bool bip32derivs = false) const override; + TransactionError FillPSBT(PartiallySignedTransaction& psbt, int sighash_type = 1 /* SIGHASH_ALL */, bool sign = true, bool bip32derivs = false, int* n_signed = nullptr) const override; uint256 GetID() const override; diff --git a/src/wallet/test/coinselector_tests.cpp b/src/wallet/test/coinselector_tests.cpp index 657d0828f2..1deedede4c 100644 --- a/src/wallet/test/coinselector_tests.cpp +++ b/src/wallet/test/coinselector_tests.cpp @@ -29,7 +29,7 @@ typedef std::set<CInputCoin> CoinSet; static std::vector<COutput> vCoins; static NodeContext testNode; static auto testChain = interfaces::MakeChain(testNode); -static CWallet testWallet(testChain.get(), WalletLocation(), WalletDatabase::CreateDummy()); +static CWallet testWallet(testChain.get(), WalletLocation(), CreateDummyWalletDatabase()); static CAmount balance = 0; CoinEligibilityFilter filter_standard(1, 6, 0); @@ -283,7 +283,7 @@ BOOST_AUTO_TEST_CASE(bnb_search_test) // Make sure that can use BnB when there are preset inputs empty_wallet(); { - std::unique_ptr<CWallet> wallet = MakeUnique<CWallet>(m_chain.get(), WalletLocation(), WalletDatabase::CreateMock()); + std::unique_ptr<CWallet> wallet = MakeUnique<CWallet>(m_chain.get(), WalletLocation(), CreateMockWalletDatabase()); bool firstRun; wallet->LoadWallet(firstRun); wallet->SetupLegacyScriptPubKeyMan(); diff --git a/src/wallet/test/db_tests.cpp b/src/wallet/test/db_tests.cpp index f4a4c9fa7c..8f0083cd2e 100644 --- a/src/wallet/test/db_tests.cpp +++ b/src/wallet/test/db_tests.cpp @@ -8,7 +8,7 @@ #include <fs.h> #include <test/util/setup_common.h> -#include <wallet/db.h> +#include <wallet/bdb.h> BOOST_FIXTURE_TEST_SUITE(db_tests, BasicTestingSetup) diff --git a/src/wallet/test/init_test_fixture.cpp b/src/wallet/test/init_test_fixture.cpp index 797a0d634f..35bd965673 100644 --- a/src/wallet/test/init_test_fixture.cpp +++ b/src/wallet/test/init_test_fixture.cpp @@ -3,13 +3,14 @@ // file COPYING or http://www.opensource.org/licenses/mit-license.php. #include <fs.h> +#include <util/check.h> #include <util/system.h> #include <wallet/test/init_test_fixture.h> -InitWalletDirTestingSetup::InitWalletDirTestingSetup(const std::string& chainName): BasicTestingSetup(chainName) +InitWalletDirTestingSetup::InitWalletDirTestingSetup(const std::string& chainName) : BasicTestingSetup(chainName) { - m_chain_client = MakeWalletClient(*m_chain, {}); + m_chain_client = MakeWalletClient(*m_chain, *Assert(m_node.args), {}); std::string sep; sep += fs::path::preferred_separator; diff --git a/src/wallet/test/init_test_fixture.h b/src/wallet/test/init_test_fixture.h index 6ba7d66b7c..c95b4f1f6e 100644 --- a/src/wallet/test/init_test_fixture.h +++ b/src/wallet/test/init_test_fixture.h @@ -18,7 +18,6 @@ struct InitWalletDirTestingSetup: public BasicTestingSetup { fs::path m_datadir; fs::path m_cwd; std::map<std::string, fs::path> m_walletdir_path_cases; - NodeContext m_node; std::unique_ptr<interfaces::Chain> m_chain = interfaces::MakeChain(m_node); std::unique_ptr<interfaces::ChainClient> m_chain_client; }; diff --git a/src/wallet/test/ismine_tests.cpp b/src/wallet/test/ismine_tests.cpp index 4c0e4dc653..e416f16044 100644 --- a/src/wallet/test/ismine_tests.cpp +++ b/src/wallet/test/ismine_tests.cpp @@ -35,7 +35,7 @@ BOOST_AUTO_TEST_CASE(ismine_standard) // P2PK compressed { - CWallet keystore(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet keystore(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); keystore.SetupLegacyScriptPubKeyMan(); LOCK(keystore.GetLegacyScriptPubKeyMan()->cs_KeyStore); scriptPubKey = GetScriptForRawPubKey(pubkeys[0]); @@ -52,7 +52,7 @@ BOOST_AUTO_TEST_CASE(ismine_standard) // P2PK uncompressed { - CWallet keystore(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet keystore(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); keystore.SetupLegacyScriptPubKeyMan(); LOCK(keystore.GetLegacyScriptPubKeyMan()->cs_KeyStore); scriptPubKey = GetScriptForRawPubKey(uncompressedPubkey); @@ -69,7 +69,7 @@ BOOST_AUTO_TEST_CASE(ismine_standard) // P2PKH compressed { - CWallet keystore(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet keystore(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); keystore.SetupLegacyScriptPubKeyMan(); LOCK(keystore.GetLegacyScriptPubKeyMan()->cs_KeyStore); scriptPubKey = GetScriptForDestination(PKHash(pubkeys[0])); @@ -86,7 +86,7 @@ BOOST_AUTO_TEST_CASE(ismine_standard) // P2PKH uncompressed { - CWallet keystore(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet keystore(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); keystore.SetupLegacyScriptPubKeyMan(); LOCK(keystore.GetLegacyScriptPubKeyMan()->cs_KeyStore); scriptPubKey = GetScriptForDestination(PKHash(uncompressedPubkey)); @@ -103,7 +103,7 @@ BOOST_AUTO_TEST_CASE(ismine_standard) // P2SH { - CWallet keystore(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet keystore(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); keystore.SetupLegacyScriptPubKeyMan(); LOCK(keystore.GetLegacyScriptPubKeyMan()->cs_KeyStore); @@ -127,7 +127,7 @@ BOOST_AUTO_TEST_CASE(ismine_standard) // (P2PKH inside) P2SH inside P2SH (invalid) { - CWallet keystore(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet keystore(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); keystore.SetupLegacyScriptPubKeyMan(); LOCK(keystore.GetLegacyScriptPubKeyMan()->cs_KeyStore); @@ -145,7 +145,7 @@ BOOST_AUTO_TEST_CASE(ismine_standard) // (P2PKH inside) P2SH inside P2WSH (invalid) { - CWallet keystore(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet keystore(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); keystore.SetupLegacyScriptPubKeyMan(); LOCK(keystore.GetLegacyScriptPubKeyMan()->cs_KeyStore); @@ -163,11 +163,11 @@ BOOST_AUTO_TEST_CASE(ismine_standard) // P2WPKH inside P2WSH (invalid) { - CWallet keystore(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet keystore(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); keystore.SetupLegacyScriptPubKeyMan(); LOCK(keystore.GetLegacyScriptPubKeyMan()->cs_KeyStore); - CScript witnessscript = GetScriptForDestination(WitnessV0KeyHash(PKHash(pubkeys[0]))); + CScript witnessscript = GetScriptForDestination(WitnessV0KeyHash(pubkeys[0])); scriptPubKey = GetScriptForDestination(WitnessV0ScriptHash(witnessscript)); BOOST_CHECK(keystore.GetLegacyScriptPubKeyMan()->AddCScript(witnessscript)); @@ -179,7 +179,7 @@ BOOST_AUTO_TEST_CASE(ismine_standard) // (P2PKH inside) P2WSH inside P2WSH (invalid) { - CWallet keystore(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet keystore(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); keystore.SetupLegacyScriptPubKeyMan(); LOCK(keystore.GetLegacyScriptPubKeyMan()->cs_KeyStore); @@ -197,12 +197,12 @@ BOOST_AUTO_TEST_CASE(ismine_standard) // P2WPKH compressed { - CWallet keystore(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet keystore(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); keystore.SetupLegacyScriptPubKeyMan(); LOCK(keystore.GetLegacyScriptPubKeyMan()->cs_KeyStore); BOOST_CHECK(keystore.GetLegacyScriptPubKeyMan()->AddKey(keys[0])); - scriptPubKey = GetScriptForDestination(WitnessV0KeyHash(PKHash(pubkeys[0]))); + scriptPubKey = GetScriptForDestination(WitnessV0KeyHash(pubkeys[0])); // Keystore implicitly has key and P2SH redeemScript BOOST_CHECK(keystore.GetLegacyScriptPubKeyMan()->AddCScript(scriptPubKey)); @@ -212,12 +212,12 @@ BOOST_AUTO_TEST_CASE(ismine_standard) // P2WPKH uncompressed { - CWallet keystore(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet keystore(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); keystore.SetupLegacyScriptPubKeyMan(); LOCK(keystore.GetLegacyScriptPubKeyMan()->cs_KeyStore); BOOST_CHECK(keystore.GetLegacyScriptPubKeyMan()->AddKey(uncompressedKey)); - scriptPubKey = GetScriptForDestination(WitnessV0KeyHash(PKHash(uncompressedPubkey))); + scriptPubKey = GetScriptForDestination(WitnessV0KeyHash(uncompressedPubkey)); // Keystore has key, but no P2SH redeemScript result = keystore.GetLegacyScriptPubKeyMan()->IsMine(scriptPubKey); @@ -231,7 +231,7 @@ BOOST_AUTO_TEST_CASE(ismine_standard) // scriptPubKey multisig { - CWallet keystore(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet keystore(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); keystore.SetupLegacyScriptPubKeyMan(); LOCK(keystore.GetLegacyScriptPubKeyMan()->cs_KeyStore); @@ -262,7 +262,7 @@ BOOST_AUTO_TEST_CASE(ismine_standard) // P2SH multisig { - CWallet keystore(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet keystore(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); keystore.SetupLegacyScriptPubKeyMan(); LOCK(keystore.GetLegacyScriptPubKeyMan()->cs_KeyStore); BOOST_CHECK(keystore.GetLegacyScriptPubKeyMan()->AddKey(uncompressedKey)); @@ -283,7 +283,7 @@ BOOST_AUTO_TEST_CASE(ismine_standard) // P2WSH multisig with compressed keys { - CWallet keystore(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet keystore(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); keystore.SetupLegacyScriptPubKeyMan(); LOCK(keystore.GetLegacyScriptPubKeyMan()->cs_KeyStore); BOOST_CHECK(keystore.GetLegacyScriptPubKeyMan()->AddKey(keys[0])); @@ -309,7 +309,7 @@ BOOST_AUTO_TEST_CASE(ismine_standard) // P2WSH multisig with uncompressed key { - CWallet keystore(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet keystore(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); keystore.SetupLegacyScriptPubKeyMan(); LOCK(keystore.GetLegacyScriptPubKeyMan()->cs_KeyStore); BOOST_CHECK(keystore.GetLegacyScriptPubKeyMan()->AddKey(uncompressedKey)); @@ -335,7 +335,7 @@ BOOST_AUTO_TEST_CASE(ismine_standard) // P2WSH multisig wrapped in P2SH { - CWallet keystore(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet keystore(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); keystore.SetupLegacyScriptPubKeyMan(); LOCK(keystore.GetLegacyScriptPubKeyMan()->cs_KeyStore); @@ -362,7 +362,7 @@ BOOST_AUTO_TEST_CASE(ismine_standard) // OP_RETURN { - CWallet keystore(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet keystore(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); keystore.SetupLegacyScriptPubKeyMan(); LOCK(keystore.GetLegacyScriptPubKeyMan()->cs_KeyStore); BOOST_CHECK(keystore.GetLegacyScriptPubKeyMan()->AddKey(keys[0])); @@ -376,7 +376,7 @@ BOOST_AUTO_TEST_CASE(ismine_standard) // witness unspendable { - CWallet keystore(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet keystore(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); keystore.SetupLegacyScriptPubKeyMan(); LOCK(keystore.GetLegacyScriptPubKeyMan()->cs_KeyStore); BOOST_CHECK(keystore.GetLegacyScriptPubKeyMan()->AddKey(keys[0])); @@ -390,7 +390,7 @@ BOOST_AUTO_TEST_CASE(ismine_standard) // witness unknown { - CWallet keystore(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet keystore(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); keystore.SetupLegacyScriptPubKeyMan(); LOCK(keystore.GetLegacyScriptPubKeyMan()->cs_KeyStore); BOOST_CHECK(keystore.GetLegacyScriptPubKeyMan()->AddKey(keys[0])); @@ -404,7 +404,7 @@ BOOST_AUTO_TEST_CASE(ismine_standard) // Nonstandard { - CWallet keystore(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet keystore(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); keystore.SetupLegacyScriptPubKeyMan(); LOCK(keystore.GetLegacyScriptPubKeyMan()->cs_KeyStore); BOOST_CHECK(keystore.GetLegacyScriptPubKeyMan()->AddKey(keys[0])); diff --git a/src/wallet/test/psbt_wallet_tests.cpp b/src/wallet/test/psbt_wallet_tests.cpp index b4c65a8665..ce7e661b67 100644 --- a/src/wallet/test/psbt_wallet_tests.cpp +++ b/src/wallet/test/psbt_wallet_tests.cpp @@ -63,8 +63,8 @@ BOOST_AUTO_TEST_CASE(psbt_updater_test) // Get the final tx CDataStream ssTx(SER_NETWORK, PROTOCOL_VERSION); ssTx << psbtx; - std::string final_hex = HexStr(ssTx.begin(), ssTx.end()); - BOOST_CHECK_EQUAL(final_hex, "70736274ff01009a020000000258e87a21b56daf0c23be8e7070456c336f7cbaa5c8757924f545887bb2abdd750000000000ffffffff838d0427d0ec650a68aa46bb0b098aea4422c071b2ca78352a077959d07cea1d0100000000ffffffff0270aaf00800000000160014d85c2b71d0060b09c9886aeb815e50991dda124d00e1f5050000000016001400aea9a2e5f0f876a588df5546e8742d1d87008f00000000000100bb0200000001aad73931018bd25f84ae400b68848be09db706eac2ac18298babee71ab656f8b0000000048473044022058f6fc7c6a33e1b31548d481c826c015bd30135aad42cd67790dab66d2ad243b02204a1ced2604c6735b6393e5b41691dd78b00f0c5942fb9f751856faa938157dba01feffffff0280f0fa020000000017a9140fb9463421696b82c833af241c78c17ddbde493487d0f20a270100000017a91429ca74f8a08f81999428185c97b5d852e4063f6187650000000104475221029583bf39ae0a609747ad199addd634fa6108559d6c5cd39b4c2183f1ab96e07f2102dab61ff49a14db6a7d02b0cd1fbb78fc4b18312b5b4e54dae4dba2fbfef536d752ae2206029583bf39ae0a609747ad199addd634fa6108559d6c5cd39b4c2183f1ab96e07f10d90c6a4f000000800000008000000080220602dab61ff49a14db6a7d02b0cd1fbb78fc4b18312b5b4e54dae4dba2fbfef536d710d90c6a4f0000008000000080010000800001012000c2eb0b0000000017a914b7f5faf40e3d40a5a459b1db3535f2b72fa921e88701042200208c2353173743b595dfb4a07b72ba8e42e3797da74e87fe7d9d7497e3b2028903010547522103089dc10c7ac6db54f91329af617333db388cead0c231f723379d1b99030b02dc21023add904f3d6dcf59ddb906b0dee23529b7ffb9ed50e5e86151926860221f0e7352ae2206023add904f3d6dcf59ddb906b0dee23529b7ffb9ed50e5e86151926860221f0e7310d90c6a4f000000800000008003000080220603089dc10c7ac6db54f91329af617333db388cead0c231f723379d1b99030b02dc10d90c6a4f00000080000000800200008000220203a9a4c37f5996d3aa25dbac6b570af0650394492942460b354753ed9eeca5877110d90c6a4f000000800000008004000080002202027f6399757d2eff55a136ad02c684b1838b6556e5f1b6b34282a94b6b5005109610d90c6a4f00000080000000800500008000"); + std::string final_hex = HexStr(ssTx); + BOOST_CHECK_EQUAL(final_hex, "70736274ff01009a020000000258e87a21b56daf0c23be8e7070456c336f7cbaa5c8757924f545887bb2abdd750000000000ffffffff838d0427d0ec650a68aa46bb0b098aea4422c071b2ca78352a077959d07cea1d0100000000ffffffff0270aaf00800000000160014d85c2b71d0060b09c9886aeb815e50991dda124d00e1f5050000000016001400aea9a2e5f0f876a588df5546e8742d1d87008f00000000000100bb0200000001aad73931018bd25f84ae400b68848be09db706eac2ac18298babee71ab656f8b0000000048473044022058f6fc7c6a33e1b31548d481c826c015bd30135aad42cd67790dab66d2ad243b02204a1ced2604c6735b6393e5b41691dd78b00f0c5942fb9f751856faa938157dba01feffffff0280f0fa020000000017a9140fb9463421696b82c833af241c78c17ddbde493487d0f20a270100000017a91429ca74f8a08f81999428185c97b5d852e4063f6187650000000104475221029583bf39ae0a609747ad199addd634fa6108559d6c5cd39b4c2183f1ab96e07f2102dab61ff49a14db6a7d02b0cd1fbb78fc4b18312b5b4e54dae4dba2fbfef536d752ae2206029583bf39ae0a609747ad199addd634fa6108559d6c5cd39b4c2183f1ab96e07f10d90c6a4f000000800000008000000080220602dab61ff49a14db6a7d02b0cd1fbb78fc4b18312b5b4e54dae4dba2fbfef536d710d90c6a4f0000008000000080010000800001008a020000000158e87a21b56daf0c23be8e7070456c336f7cbaa5c8757924f545887bb2abdd7501000000171600145f275f436b09a8cc9a2eb2a2f528485c68a56323feffffff02d8231f1b0100000017a914aed962d6654f9a2b36608eb9d64d2b260db4f1118700c2eb0b0000000017a914b7f5faf40e3d40a5a459b1db3535f2b72fa921e8876500000001012000c2eb0b0000000017a914b7f5faf40e3d40a5a459b1db3535f2b72fa921e88701042200208c2353173743b595dfb4a07b72ba8e42e3797da74e87fe7d9d7497e3b2028903010547522103089dc10c7ac6db54f91329af617333db388cead0c231f723379d1b99030b02dc21023add904f3d6dcf59ddb906b0dee23529b7ffb9ed50e5e86151926860221f0e7352ae2206023add904f3d6dcf59ddb906b0dee23529b7ffb9ed50e5e86151926860221f0e7310d90c6a4f000000800000008003000080220603089dc10c7ac6db54f91329af617333db388cead0c231f723379d1b99030b02dc10d90c6a4f00000080000000800200008000220203a9a4c37f5996d3aa25dbac6b570af0650394492942460b354753ed9eeca5877110d90c6a4f000000800000008004000080002202027f6399757d2eff55a136ad02c684b1838b6556e5f1b6b34282a94b6b5005109610d90c6a4f00000080000000800500008000"); // Mutate the transaction so that one of the inputs is invalid psbtx.tx->vin[0].prevout.n = 2; diff --git a/src/wallet/test/scriptpubkeyman_tests.cpp b/src/wallet/test/scriptpubkeyman_tests.cpp index 757865ea37..4f12079768 100644 --- a/src/wallet/test/scriptpubkeyman_tests.cpp +++ b/src/wallet/test/scriptpubkeyman_tests.cpp @@ -19,7 +19,7 @@ BOOST_AUTO_TEST_CASE(CanProvide) // Set up wallet and keyman variables. NodeContext node; std::unique_ptr<interfaces::Chain> chain = interfaces::MakeChain(node); - CWallet wallet(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet wallet(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); LegacyScriptPubKeyMan& keyman = *wallet.GetOrCreateLegacyScriptPubKeyMan(); // Make a 1 of 2 multisig script diff --git a/src/wallet/test/wallet_crypto_tests.cpp b/src/wallet/test/wallet_crypto_tests.cpp index 97f8c94fa6..10ddfa22ef 100644 --- a/src/wallet/test/wallet_crypto_tests.cpp +++ b/src/wallet/test/wallet_crypto_tests.cpp @@ -24,10 +24,10 @@ static void TestPassphraseSingle(const std::vector<unsigned char>& vchSalt, cons if(!correctKey.empty()) BOOST_CHECK_MESSAGE(memcmp(crypt.vchKey.data(), correctKey.data(), crypt.vchKey.size()) == 0, \ - HexStr(crypt.vchKey.begin(), crypt.vchKey.end()) + std::string(" != ") + HexStr(correctKey.begin(), correctKey.end())); + HexStr(crypt.vchKey) + std::string(" != ") + HexStr(correctKey)); if(!correctIV.empty()) BOOST_CHECK_MESSAGE(memcmp(crypt.vchIV.data(), correctIV.data(), crypt.vchIV.size()) == 0, - HexStr(crypt.vchIV.begin(), crypt.vchIV.end()) + std::string(" != ") + HexStr(correctIV.begin(), correctIV.end())); + HexStr(crypt.vchIV) + std::string(" != ") + HexStr(correctIV)); } static void TestPassphrase(const std::vector<unsigned char>& vchSalt, const SecureString& passphrase, uint32_t rounds, diff --git a/src/wallet/test/wallet_test_fixture.cpp b/src/wallet/test/wallet_test_fixture.cpp index 7ba3148ff3..44f9eb5ddd 100644 --- a/src/wallet/test/wallet_test_fixture.cpp +++ b/src/wallet/test/wallet_test_fixture.cpp @@ -6,7 +6,7 @@ WalletTestingSetup::WalletTestingSetup(const std::string& chainName) : TestingSetup(chainName), - m_wallet(m_chain.get(), WalletLocation(), WalletDatabase::CreateMock()) + m_wallet(m_chain.get(), WalletLocation(), CreateMockWalletDatabase()) { bool fFirstRun; m_wallet.LoadWallet(fFirstRun); diff --git a/src/wallet/test/wallet_test_fixture.h b/src/wallet/test/wallet_test_fixture.h index a294935b64..99d7cfe921 100644 --- a/src/wallet/test/wallet_test_fixture.h +++ b/src/wallet/test/wallet_test_fixture.h @@ -10,18 +10,18 @@ #include <interfaces/chain.h> #include <interfaces/wallet.h> #include <node/context.h> +#include <util/check.h> #include <wallet/wallet.h> #include <memory> /** Testing setup and teardown for wallet. */ -struct WalletTestingSetup: public TestingSetup { +struct WalletTestingSetup : public TestingSetup { explicit WalletTestingSetup(const std::string& chainName = CBaseChainParams::MAIN); - NodeContext m_node; std::unique_ptr<interfaces::Chain> m_chain = interfaces::MakeChain(m_node); - std::unique_ptr<interfaces::ChainClient> m_chain_client = interfaces::MakeWalletClient(*m_chain, {}); + std::unique_ptr<interfaces::ChainClient> m_chain_client = interfaces::MakeWalletClient(*m_chain, *Assert(m_node.args), {}); CWallet m_wallet; std::unique_ptr<interfaces::Handler> m_chain_notifications_handler; }; diff --git a/src/wallet/test/wallet_tests.cpp b/src/wallet/test/wallet_tests.cpp index 3654420eb2..7ef06663b5 100644 --- a/src/wallet/test/wallet_tests.cpp +++ b/src/wallet/test/wallet_tests.cpp @@ -28,6 +28,11 @@ extern UniValue importmulti(const JSONRPCRequest& request); extern UniValue dumpwallet(const JSONRPCRequest& request); extern UniValue importwallet(const JSONRPCRequest& request); +// Ensure that fee levels defined in the wallet are at least as high +// as the default levels for node policy. +static_assert(DEFAULT_TRANSACTION_MINFEE >= DEFAULT_MIN_RELAY_TX_FEE, "wallet minimum fee is smaller than default relay fee"); +static_assert(WALLET_INCREMENTAL_RELAY_FEE >= DEFAULT_INCREMENTAL_RELAY_FEE, "wallet incremental fee is smaller than default incremental relay fee"); + BOOST_FIXTURE_TEST_SUITE(wallet_tests, WalletTestingSetup) static std::shared_ptr<CWallet> TestLoadWallet(interfaces::Chain& chain) @@ -80,7 +85,7 @@ BOOST_FIXTURE_TEST_CASE(scan_for_wallet_transactions, TestChain100Setup) // Verify ScanForWalletTransactions fails to read an unknown start block. { - CWallet wallet(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet wallet(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); { LOCK(wallet.cs_wallet); wallet.SetLastBlockProcessed(::ChainActive().Height(), ::ChainActive().Tip()->GetBlockHash()); @@ -99,7 +104,7 @@ BOOST_FIXTURE_TEST_CASE(scan_for_wallet_transactions, TestChain100Setup) // Verify ScanForWalletTransactions picks up transactions in both the old // and new block files. { - CWallet wallet(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet wallet(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); { LOCK(wallet.cs_wallet); wallet.SetLastBlockProcessed(::ChainActive().Height(), ::ChainActive().Tip()->GetBlockHash()); @@ -118,14 +123,14 @@ BOOST_FIXTURE_TEST_CASE(scan_for_wallet_transactions, TestChain100Setup) // Prune the older block file. { LOCK(cs_main); - EnsureChainman(m_node).PruneOneBlockFile(oldTip->GetBlockPos().nFile); + Assert(m_node.chainman)->PruneOneBlockFile(oldTip->GetBlockPos().nFile); } UnlinkPrunedFiles({oldTip->GetBlockPos().nFile}); // Verify ScanForWalletTransactions only picks transactions in the new block // file. { - CWallet wallet(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet wallet(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); { LOCK(wallet.cs_wallet); wallet.SetLastBlockProcessed(::ChainActive().Height(), ::ChainActive().Tip()->GetBlockHash()); @@ -144,13 +149,13 @@ BOOST_FIXTURE_TEST_CASE(scan_for_wallet_transactions, TestChain100Setup) // Prune the remaining block file. { LOCK(cs_main); - EnsureChainman(m_node).PruneOneBlockFile(newTip->GetBlockPos().nFile); + Assert(m_node.chainman)->PruneOneBlockFile(newTip->GetBlockPos().nFile); } UnlinkPrunedFiles({newTip->GetBlockPos().nFile}); // Verify ScanForWalletTransactions scans no blocks. { - CWallet wallet(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet wallet(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); { LOCK(wallet.cs_wallet); wallet.SetLastBlockProcessed(::ChainActive().Height(), ::ChainActive().Tip()->GetBlockHash()); @@ -181,7 +186,7 @@ BOOST_FIXTURE_TEST_CASE(importmulti_rescan, TestChain100Setup) // Prune the older block file. { LOCK(cs_main); - EnsureChainman(m_node).PruneOneBlockFile(oldTip->GetBlockPos().nFile); + Assert(m_node.chainman)->PruneOneBlockFile(oldTip->GetBlockPos().nFile); } UnlinkPrunedFiles({oldTip->GetBlockPos().nFile}); @@ -189,7 +194,7 @@ BOOST_FIXTURE_TEST_CASE(importmulti_rescan, TestChain100Setup) // before the missing block, and success for a key whose creation time is // after. { - std::shared_ptr<CWallet> wallet = std::make_shared<CWallet>(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + std::shared_ptr<CWallet> wallet = std::make_shared<CWallet>(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); wallet->SetupLegacyScriptPubKeyMan(); WITH_LOCK(wallet->cs_wallet, wallet->SetLastBlockProcessed(newTip->nHeight, newTip->GetBlockHash())); AddWallet(wallet); @@ -254,7 +259,7 @@ BOOST_FIXTURE_TEST_CASE(importwallet_rescan, TestChain100Setup) // Import key into wallet and call dumpwallet to create backup file. { - std::shared_ptr<CWallet> wallet = std::make_shared<CWallet>(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + std::shared_ptr<CWallet> wallet = std::make_shared<CWallet>(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); { auto spk_man = wallet->GetOrCreateLegacyScriptPubKeyMan(); LOCK2(wallet->cs_wallet, spk_man->cs_KeyStore); @@ -276,7 +281,7 @@ BOOST_FIXTURE_TEST_CASE(importwallet_rescan, TestChain100Setup) // Call importwallet RPC and verify all blocks with timestamps >= BLOCK_TIME // were scanned, and no prior blocks were scanned. { - std::shared_ptr<CWallet> wallet = std::make_shared<CWallet>(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + std::shared_ptr<CWallet> wallet = std::make_shared<CWallet>(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); LOCK(wallet->cs_wallet); wallet->SetupLegacyScriptPubKeyMan(); @@ -312,7 +317,7 @@ BOOST_FIXTURE_TEST_CASE(coin_mark_dirty_immature_credit, TestChain100Setup) NodeContext node; auto chain = interfaces::MakeChain(node); - CWallet wallet(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + CWallet wallet(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); auto spk_man = wallet.GetOrCreateLegacyScriptPubKeyMan(); CWalletTx wtx(&wallet, m_coinbase_txns.back()); @@ -333,7 +338,7 @@ BOOST_FIXTURE_TEST_CASE(coin_mark_dirty_immature_credit, TestChain100Setup) BOOST_CHECK_EQUAL(wtx.GetImmatureCredit(), 50*COIN); } -static int64_t AddTx(CWallet& wallet, uint32_t lockTime, int64_t mockTime, int64_t blockTime) +static int64_t AddTx(ChainstateManager& chainman, CWallet& wallet, uint32_t lockTime, int64_t mockTime, int64_t blockTime) { CMutableTransaction tx; CWalletTx::Confirmation confirm; @@ -341,7 +346,8 @@ static int64_t AddTx(CWallet& wallet, uint32_t lockTime, int64_t mockTime, int64 SetMockTime(mockTime); CBlockIndex* block = nullptr; if (blockTime > 0) { - auto inserted = ::BlockIndex().emplace(GetRandHash(), new CBlockIndex); + LOCK(cs_main); + auto inserted = chainman.BlockIndex().emplace(GetRandHash(), new CBlockIndex); assert(inserted.second); const uint256& hash = inserted.first->first; block = inserted.first->second; @@ -363,24 +369,24 @@ static int64_t AddTx(CWallet& wallet, uint32_t lockTime, int64_t mockTime, int64 BOOST_AUTO_TEST_CASE(ComputeTimeSmart) { // New transaction should use clock time if lower than block time. - BOOST_CHECK_EQUAL(AddTx(m_wallet, 1, 100, 120), 100); + BOOST_CHECK_EQUAL(AddTx(*m_node.chainman, m_wallet, 1, 100, 120), 100); // Test that updating existing transaction does not change smart time. - BOOST_CHECK_EQUAL(AddTx(m_wallet, 1, 200, 220), 100); + BOOST_CHECK_EQUAL(AddTx(*m_node.chainman, m_wallet, 1, 200, 220), 100); // New transaction should use clock time if there's no block time. - BOOST_CHECK_EQUAL(AddTx(m_wallet, 2, 300, 0), 300); + BOOST_CHECK_EQUAL(AddTx(*m_node.chainman, m_wallet, 2, 300, 0), 300); // New transaction should use block time if lower than clock time. - BOOST_CHECK_EQUAL(AddTx(m_wallet, 3, 420, 400), 400); + BOOST_CHECK_EQUAL(AddTx(*m_node.chainman, m_wallet, 3, 420, 400), 400); // New transaction should use latest entry time if higher than // min(block time, clock time). - BOOST_CHECK_EQUAL(AddTx(m_wallet, 4, 500, 390), 400); + BOOST_CHECK_EQUAL(AddTx(*m_node.chainman, m_wallet, 4, 500, 390), 400); // If there are future entries, new transaction should use time of the // newest entry that is no more than 300 seconds ahead of the clock time. - BOOST_CHECK_EQUAL(AddTx(m_wallet, 5, 50, 600), 300); + BOOST_CHECK_EQUAL(AddTx(*m_node.chainman, m_wallet, 5, 50, 600), 300); // Reset mock time for other tests. SetMockTime(0); @@ -486,7 +492,7 @@ public: ListCoinsTestingSetup() { CreateAndProcessBlock({}, GetScriptForRawPubKey(coinbaseKey.GetPubKey())); - wallet = MakeUnique<CWallet>(m_chain.get(), WalletLocation(), WalletDatabase::CreateMock()); + wallet = MakeUnique<CWallet>(m_chain.get(), WalletLocation(), CreateMockWalletDatabase()); { LOCK2(wallet->cs_wallet, ::cs_main); wallet->SetLastBlockProcessed(::ChainActive().Height(), ::ChainActive().Tip()->GetBlockHash()); @@ -535,7 +541,6 @@ public: return it->second; } - NodeContext m_node; std::unique_ptr<interfaces::Chain> m_chain = interfaces::MakeChain(m_node); std::unique_ptr<CWallet> wallet; }; @@ -605,7 +610,7 @@ BOOST_FIXTURE_TEST_CASE(wallet_disableprivkeys, TestChain100Setup) { NodeContext node; auto chain = interfaces::MakeChain(node); - std::shared_ptr<CWallet> wallet = std::make_shared<CWallet>(chain.get(), WalletLocation(), WalletDatabase::CreateDummy()); + std::shared_ptr<CWallet> wallet = std::make_shared<CWallet>(chain.get(), WalletLocation(), CreateDummyWalletDatabase()); wallet->SetupLegacyScriptPubKeyMan(); wallet->SetMinVersion(FEATURE_LATEST); wallet->SetWalletFlag(WALLET_FLAG_DISABLE_PRIVATE_KEYS); @@ -625,13 +630,13 @@ static size_t CalculateNestedKeyhashInputSize(bool use_max_sig) CPubKey pubkey = key.GetPubKey(); // Generate pubkey hash - uint160 key_hash(Hash160(pubkey.begin(), pubkey.end())); + uint160 key_hash(Hash160(pubkey)); // Create inner-script to enter into keystore. Key hash can't be 0... CScript inner_script = CScript() << OP_0 << std::vector<unsigned char>(key_hash.begin(), key_hash.end()); // Create outer P2SH script for the output - uint160 script_id(Hash160(inner_script.begin(), inner_script.end())); + uint160 script_id(Hash160(inner_script)); CScript script_pubkey = CScript() << OP_HASH160 << std::vector<unsigned char>(script_id.begin(), script_id.end()) << OP_EQUAL; // Add inner-script to key store and key to watchonly @@ -791,4 +796,37 @@ BOOST_FIXTURE_TEST_CASE(CreateWalletFromFile, TestChain100Setup) TestUnloadWallet(std::move(wallet)); } +BOOST_FIXTURE_TEST_CASE(ZapSelectTx, TestChain100Setup) +{ + auto chain = interfaces::MakeChain(m_node); + auto wallet = TestLoadWallet(*chain); + CKey key; + key.MakeNewKey(true); + AddKey(*wallet, key); + + std::string error; + m_coinbase_txns.push_back(CreateAndProcessBlock({}, GetScriptForRawPubKey(coinbaseKey.GetPubKey())).vtx[0]); + auto block_tx = TestSimpleSpend(*m_coinbase_txns[0], 0, coinbaseKey, GetScriptForRawPubKey(key.GetPubKey())); + CreateAndProcessBlock({block_tx}, GetScriptForRawPubKey(coinbaseKey.GetPubKey())); + + SyncWithValidationInterfaceQueue(); + + { + auto block_hash = block_tx.GetHash(); + auto prev_hash = m_coinbase_txns[0]->GetHash(); + + LOCK(wallet->cs_wallet); + BOOST_CHECK(wallet->HasWalletSpend(prev_hash)); + BOOST_CHECK_EQUAL(wallet->mapWallet.count(block_hash), 1u); + + std::vector<uint256> vHashIn{ block_hash }, vHashOut; + BOOST_CHECK_EQUAL(wallet->ZapSelectTx(vHashIn, vHashOut), DBErrors::LOAD_OK); + + BOOST_CHECK(!wallet->HasWalletSpend(prev_hash)); + BOOST_CHECK_EQUAL(wallet->mapWallet.count(block_hash), 0u); + } + + TestUnloadWallet(std::move(wallet)); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp index 7824563254..cee2f2214c 100644 --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -21,6 +21,7 @@ #include <script/descriptor.h> #include <script/script.h> #include <script/signingprovider.h> +#include <txmempool.h> #include <util/bip32.h> #include <util/check.h> #include <util/error.h> @@ -76,12 +77,6 @@ bool RemoveWallet(const std::shared_ptr<CWallet>& wallet) return true; } -bool HasWallets() -{ - LOCK(cs_wallets); - return !vpwallets.empty(); -} - std::vector<std::shared_ptr<CWallet>> GetWallets() { LOCK(cs_wallets); @@ -104,9 +99,11 @@ std::unique_ptr<interfaces::Handler> HandleLoadWallet(LoadWalletFn load_wallet) return interfaces::MakeHandler([it] { LOCK(cs_wallets); g_load_wallet_fns.erase(it); }); } +static Mutex g_loading_wallet_mutex; static Mutex g_wallet_release_mutex; static std::condition_variable g_wallet_release_cv; -static std::set<std::string> g_unloading_wallet_set; +static std::set<std::string> g_loading_wallet_set GUARDED_BY(g_loading_wallet_mutex); +static std::set<std::string> g_unloading_wallet_set GUARDED_BY(g_wallet_release_mutex); // Custom deleter for shared_ptr<CWallet>. static void ReleaseWallet(CWallet* wallet) @@ -150,7 +147,8 @@ void UnloadWallet(std::shared_ptr<CWallet>&& wallet) } } -std::shared_ptr<CWallet> LoadWallet(interfaces::Chain& chain, const WalletLocation& location, bilingual_str& error, std::vector<bilingual_str>& warnings) +namespace { +std::shared_ptr<CWallet> LoadWalletInternal(interfaces::Chain& chain, const WalletLocation& location, bilingual_str& error, std::vector<bilingual_str>& warnings) { try { if (!CWallet::Verify(chain, location, error, warnings)) { @@ -171,6 +169,19 @@ std::shared_ptr<CWallet> LoadWallet(interfaces::Chain& chain, const WalletLocati return nullptr; } } +} // namespace + +std::shared_ptr<CWallet> LoadWallet(interfaces::Chain& chain, const WalletLocation& location, bilingual_str& error, std::vector<bilingual_str>& warnings) +{ + auto result = WITH_LOCK(g_loading_wallet_mutex, return g_loading_wallet_set.insert(location.GetName())); + if (!result.second) { + error = Untranslated("Wallet already being loading."); + return nullptr; + } + auto wallet = LoadWalletInternal(chain, location, error, warnings); + WITH_LOCK(g_loading_wallet_mutex, g_loading_wallet_set.erase(result.first)); + return wallet; +} std::shared_ptr<CWallet> LoadWallet(interfaces::Chain& chain, const std::string& name, bilingual_str& error, std::vector<bilingual_str>& warnings) { @@ -428,9 +439,14 @@ bool CWallet::HasWalletSpend(const uint256& txid) const return (iter != mapTxSpends.end() && iter->first.hash == txid); } -void CWallet::Flush(bool shutdown) +void CWallet::Flush() +{ + database->Flush(); +} + +void CWallet::Close() { - database->Flush(shutdown); + database->Close(); } void CWallet::SyncMetaData(std::pair<TxSpends::iterator, TxSpends::iterator> range) @@ -751,7 +767,6 @@ void CWallet::SetSpentKeyState(WalletBatch& batch, const uint256& hash, unsigned bool CWallet::IsSpentKey(const uint256& hash, unsigned int n) const { AssertLockHeld(cs_wallet); - CTxDestination dst; const CWalletTx* srctx = GetWalletTx(hash); if (srctx) { assert(srctx->tx->vout.size() > n); @@ -1100,23 +1115,52 @@ void CWallet::SyncTransaction(const CTransactionRef& ptx, CWalletTx::Confirmatio MarkInputsDirty(ptx); } -void CWallet::transactionAddedToMempool(const CTransactionRef& ptx) { +void CWallet::transactionAddedToMempool(const CTransactionRef& tx) { LOCK(cs_wallet); - CWalletTx::Confirmation confirm(CWalletTx::Status::UNCONFIRMED, /* block_height */ 0, {}, /* nIndex */ 0); - SyncTransaction(ptx, confirm); + SyncTransaction(tx, {CWalletTx::Status::UNCONFIRMED, /* block height */ 0, /* block hash */ {}, /* index */ 0}); - auto it = mapWallet.find(ptx->GetHash()); + auto it = mapWallet.find(tx->GetHash()); if (it != mapWallet.end()) { it->second.fInMempool = true; } } -void CWallet::transactionRemovedFromMempool(const CTransactionRef &ptx) { +void CWallet::transactionRemovedFromMempool(const CTransactionRef& tx, MemPoolRemovalReason reason) { LOCK(cs_wallet); - auto it = mapWallet.find(ptx->GetHash()); + auto it = mapWallet.find(tx->GetHash()); if (it != mapWallet.end()) { it->second.fInMempool = false; } + // Handle transactions that were removed from the mempool because they + // conflict with transactions in a newly connected block. + if (reason == MemPoolRemovalReason::CONFLICT) { + // Call SyncNotifications, so external -walletnotify notifications will + // be triggered for these transactions. Set Status::UNCONFIRMED instead + // of Status::CONFLICTED for a few reasons: + // + // 1. The transactionRemovedFromMempool callback does not currently + // provide the conflicting block's hash and height, and for backwards + // compatibility reasons it may not be not safe to store conflicted + // wallet transactions with a null block hash. See + // https://github.com/bitcoin/bitcoin/pull/18600#discussion_r420195993. + // 2. For most of these transactions, the wallet's internal conflict + // detection in the blockConnected handler will subsequently call + // MarkConflicted and update them with CONFLICTED status anyway. This + // applies to any wallet transaction that has inputs spent in the + // block, or that has ancestors in the wallet with inputs spent by + // the block. + // 3. Longstanding behavior since the sync implementation in + // https://github.com/bitcoin/bitcoin/pull/9371 and the prior sync + // implementation before that was to mark these transactions + // unconfirmed rather than conflicted. + // + // Nothing described above should be seen as an unchangeable requirement + // when improving this code in the future. The wallet's heuristics for + // distinguishing between conflicted and unconfirmed transactions are + // imperfect, and could be improved in general, see + // https://github.com/bitcoin-core/bitcoin-devwiki/wiki/Wallet-Transaction-Conflict-Tracking + SyncTransaction(tx, {CWalletTx::Status::UNCONFIRMED, /* block height */ 0, /* block hash */ {}, /* index */ 0}); + } } void CWallet::blockConnected(const CBlock& block, int height) @@ -1127,9 +1171,8 @@ void CWallet::blockConnected(const CBlock& block, int height) m_last_block_processed_height = height; m_last_block_processed = block_hash; for (size_t index = 0; index < block.vtx.size(); index++) { - CWalletTx::Confirmation confirm(CWalletTx::Status::CONFIRMED, height, block_hash, index); - SyncTransaction(block.vtx[index], confirm); - transactionRemovedFromMempool(block.vtx[index]); + SyncTransaction(block.vtx[index], {CWalletTx::Status::CONFIRMED, height, block_hash, (int)index}); + transactionRemovedFromMempool(block.vtx[index], MemPoolRemovalReason::BLOCK); } } @@ -1144,8 +1187,7 @@ void CWallet::blockDisconnected(const CBlock& block, int height) m_last_block_processed_height = height - 1; m_last_block_processed = block.hashPrevBlock; for (const CTransactionRef& ptx : block.vtx) { - CWalletTx::Confirmation confirm(CWalletTx::Status::UNCONFIRMED, /* block_height */ 0, {}, /* nIndex */ 0); - SyncTransaction(ptx, confirm); + SyncTransaction(ptx, {CWalletTx::Status::UNCONFIRMED, /* block height */ 0, /* block hash */ {}, /* index */ 0}); } } @@ -1385,19 +1427,28 @@ bool CWallet::IsWalletFlagSet(uint64_t flag) const return (m_wallet_flags & flag); } -bool CWallet::SetWalletFlags(uint64_t overwriteFlags, bool memonly) +bool CWallet::LoadWalletFlags(uint64_t flags) { LOCK(cs_wallet); - m_wallet_flags = overwriteFlags; - if (((overwriteFlags & KNOWN_WALLET_FLAGS) >> 32) ^ (overwriteFlags >> 32)) { + if (((flags & KNOWN_WALLET_FLAGS) >> 32) ^ (flags >> 32)) { // contains unknown non-tolerable wallet flags return false; } - if (!memonly && !WalletBatch(*database).WriteWalletFlags(m_wallet_flags)) { + m_wallet_flags = flags; + + return true; +} + +bool CWallet::AddWalletFlags(uint64_t flags) +{ + LOCK(cs_wallet); + // We should never be writing unknown non-tolerable wallet flags + assert(((flags & KNOWN_WALLET_FLAGS) >> 32) == (flags >> 32)); + if (!WalletBatch(*database).WriteWalletFlags(flags)) { throw std::runtime_error(std::string(__func__) + ": writing wallet flags failed"); } - return true; + return LoadWalletFlags(flags); } int64_t CWalletTx::GetTxTime() const @@ -1685,8 +1736,7 @@ CWallet::ScanResult CWallet::ScanForWalletTransactions(const uint256& start_bloc break; } for (size_t posInBlock = 0; posInBlock < block.vtx.size(); ++posInBlock) { - CWalletTx::Confirmation confirm(CWalletTx::Status::CONFIRMED, block_height, block_hash, posInBlock); - SyncTransaction(block.vtx[posInBlock], confirm, fUpdate); + SyncTransaction(block.vtx[posInBlock], {CWalletTx::Status::CONFIRMED, block_height, block_hash, (int)posInBlock}, fUpdate); } // scan succeeded, record block as most recent successfully scanned result.last_scanned_block = block_hash; @@ -2140,6 +2190,11 @@ void CWallet::AvailableCoins(std::vector<COutput>& vCoins, bool fOnlySafe, const } for (unsigned int i = 0; i < wtx.tx->vout.size(); i++) { + // Only consider selected coins if add_inputs is false + if (coinControl && !coinControl->m_add_inputs && !coinControl->IsSelected(COutPoint(entry.first, i))) { + continue; + } + if (wtx.tx->vout[i].nValue < nMinimumAmount || wtx.tx->vout[i].nValue > nMaximumAmount) continue; @@ -2451,8 +2506,11 @@ bool CWallet::SignTransaction(CMutableTransaction& tx, const std::map<COutPoint, return false; } -TransactionError CWallet::FillPSBT(PartiallySignedTransaction& psbtx, bool& complete, int sighash_type, bool sign, bool bip32derivs) const +TransactionError CWallet::FillPSBT(PartiallySignedTransaction& psbtx, bool& complete, int sighash_type, bool sign, bool bip32derivs, size_t * n_signed) const { + if (n_signed) { + *n_signed = 0; + } LOCK(cs_wallet); // Get all of the previous transactions for (unsigned int i = 0; i < psbtx.tx->vin.size(); ++i) { @@ -2463,13 +2521,8 @@ TransactionError CWallet::FillPSBT(PartiallySignedTransaction& psbtx, bool& comp continue; } - // Verify input looks sane. This will check that we have at most one uxto, witness or non-witness. - if (!input.IsSane()) { - return TransactionError::INVALID_PSBT; - } - // If we have no utxo, grab it from the wallet. - if (!input.non_witness_utxo && input.witness_utxo.IsNull()) { + if (!input.non_witness_utxo) { const uint256& txhash = txin.prevout.hash; const auto it = mapWallet.find(txhash); if (it != mapWallet.end()) { @@ -2483,10 +2536,15 @@ TransactionError CWallet::FillPSBT(PartiallySignedTransaction& psbtx, bool& comp // Fill in information from ScriptPubKeyMans for (ScriptPubKeyMan* spk_man : GetAllScriptPubKeyMans()) { - TransactionError res = spk_man->FillPSBT(psbtx, sighash_type, sign, bip32derivs); + int n_signed_this_spkm = 0; + TransactionError res = spk_man->FillPSBT(psbtx, sighash_type, sign, bip32derivs, &n_signed_this_spkm); if (res != TransactionError::OK) { return res; } + + if (n_signed) { + (*n_signed) += n_signed_this_spkm; + } } // Complete if every input is now signed @@ -2620,11 +2678,11 @@ static uint32_t GetLocktimeForNewTransaction(interfaces::Chain& chain, const uin return locktime; } -OutputType CWallet::TransactionChangeType(OutputType change_type, const std::vector<CRecipient>& vecSend) +OutputType CWallet::TransactionChangeType(const Optional<OutputType>& change_type, const std::vector<CRecipient>& vecSend) { // If -changetype is specified, always use that change type. - if (change_type != OutputType::CHANGE_AUTO) { - return change_type; + if (change_type) { + return *change_type; } // if m_default_address_type is legacy, use legacy address as change (even @@ -2721,6 +2779,12 @@ bool CWallet::CreateTransaction(const std::vector<CRecipient>& vecSend, CTransac // Get the fee rate to use effective values in coin selection CFeeRate nFeeRateNeeded = GetMinimumFeeRate(*this, coin_control, &feeCalc); + // Do not, ever, assume that it's fine to change the fee rate if the user has explicitly + // provided one + if (coin_control.m_feerate && nFeeRateNeeded > *coin_control.m_feerate) { + error = strprintf(_("Fee rate (%s) is lower than the minimum fee rate setting (%s)"), coin_control.m_feerate->ToString(), nFeeRateNeeded.ToString()); + return false; + } nFeeRet = 0; bool pick_new_inputs = true; @@ -2970,7 +3034,7 @@ bool CWallet::CreateTransaction(const std::vector<CRecipient>& vecSend, CTransac } if (nFeeRet > m_default_max_tx_fee) { - error = Untranslated(TransactionErrorString(TransactionError::MAX_FEE_EXCEEDED)); + error = TransactionErrorString(TransactionError::MAX_FEE_EXCEEDED); return false; } @@ -3070,9 +3134,11 @@ DBErrors CWallet::ZapSelectTx(std::vector<uint256>& vHashIn, std::vector<uint256 { AssertLockHeld(cs_wallet); DBErrors nZapSelectTxRet = WalletBatch(*database, "cr+").ZapSelectTx(vHashIn, vHashOut); - for (uint256 hash : vHashOut) { + for (const uint256& hash : vHashOut) { const auto& it = mapWallet.find(hash); wtxOrdered.erase(it->second.m_it_wtxOrdered); + for (const auto& txin : it->second.tx->vin) + mapTxSpends.erase(txin.prevout); mapWallet.erase(it); NotifyTransactionChanged(this, hash, CT_DELETED); } @@ -3679,18 +3745,14 @@ bool CWallet::Verify(interfaces::Chain& chain, const WalletLocation& location, b } // Keep same database environment instance across Verify/Recover calls below. - std::unique_ptr<WalletDatabase> database = WalletDatabase::Create(wallet_path); + std::unique_ptr<WalletDatabase> database = CreateWalletDatabase(wallet_path); try { - if (!WalletBatch::VerifyEnvironment(wallet_path, error_string)) { - return false; - } + return database->Verify(error_string); } catch (const fs::filesystem_error& e) { error_string = Untranslated(strprintf("Error loading wallet %s. %s", location.GetName(), fsbridge::get_filesystem_error_message(e))); return false; } - - return WalletBatch::VerifyDatabaseFile(wallet_path, error_string); } std::shared_ptr<CWallet> CWallet::CreateWalletFromFile(interfaces::Chain& chain, const WalletLocation& location, bilingual_str& error, std::vector<bilingual_str>& warnings, uint64_t wallet_creation_flags) @@ -3703,7 +3765,7 @@ std::shared_ptr<CWallet> CWallet::CreateWalletFromFile(interfaces::Chain& chain, if (gArgs.GetBoolArg("-zapwallettxes", false)) { chain.initMessage(_("Zapping all transactions from wallet...").translated); - std::unique_ptr<CWallet> tempWallet = MakeUnique<CWallet>(&chain, location, WalletDatabase::Create(location.GetPath())); + std::unique_ptr<CWallet> tempWallet = MakeUnique<CWallet>(&chain, location, CreateWalletDatabase(location.GetPath())); DBErrors nZapWalletRet = tempWallet->ZapWalletTx(vWtx); if (nZapWalletRet != DBErrors::LOAD_OK) { error = strprintf(_("Error loading %s: Wallet corrupted"), walletFile); @@ -3717,7 +3779,7 @@ std::shared_ptr<CWallet> CWallet::CreateWalletFromFile(interfaces::Chain& chain, bool fFirstRun = true; // TODO: Can't use std::make_shared because we need a custom deleter but // should be possible to use std::allocate_shared. - std::shared_ptr<CWallet> walletInstance(new CWallet(&chain, location, WalletDatabase::Create(location.GetPath())), ReleaseWallet); + std::shared_ptr<CWallet> walletInstance(new CWallet(&chain, location, CreateWalletDatabase(location.GetPath())), ReleaseWallet); DBErrors nLoadWalletRet = walletInstance->LoadWallet(fFirstRun); if (nLoadWalletRet != DBErrors::LOAD_OK) { if (nLoadWalletRet == DBErrors::CORRUPT) { @@ -3750,7 +3812,7 @@ std::shared_ptr<CWallet> CWallet::CreateWalletFromFile(interfaces::Chain& chain, // ensure this wallet.dat can only be opened by clients supporting HD with chain split and expects no default key walletInstance->SetMinVersion(FEATURE_LATEST); - walletInstance->SetWalletFlags(wallet_creation_flags, false); + walletInstance->AddWalletFlags(wallet_creation_flags); // Only create LegacyScriptPubKeyMan when not descriptor wallet if (!walletInstance->IsWalletFlagSet(WALLET_FLAG_DESCRIPTORS)) { @@ -3787,14 +3849,20 @@ std::shared_ptr<CWallet> CWallet::CreateWalletFromFile(interfaces::Chain& chain, } } - if (!gArgs.GetArg("-addresstype", "").empty() && !ParseOutputType(gArgs.GetArg("-addresstype", ""), walletInstance->m_default_address_type)) { - error = strprintf(_("Unknown address type '%s'"), gArgs.GetArg("-addresstype", "")); - return nullptr; + if (!gArgs.GetArg("-addresstype", "").empty()) { + if (!ParseOutputType(gArgs.GetArg("-addresstype", ""), walletInstance->m_default_address_type)) { + error = strprintf(_("Unknown address type '%s'"), gArgs.GetArg("-addresstype", "")); + return nullptr; + } } - if (!gArgs.GetArg("-changetype", "").empty() && !ParseOutputType(gArgs.GetArg("-changetype", ""), walletInstance->m_default_change_type)) { - error = strprintf(_("Unknown change type '%s'"), gArgs.GetArg("-changetype", "")); - return nullptr; + if (!gArgs.GetArg("-changetype", "").empty()) { + OutputType out_type; + if (!ParseOutputType(gArgs.GetArg("-changetype", ""), out_type)) { + error = strprintf(_("Unknown change type '%s'"), gArgs.GetArg("-changetype", "")); + return nullptr; + } + walletInstance->m_default_change_type = out_type; } if (gArgs.IsArgSet("-mintxfee")) { @@ -4365,12 +4433,21 @@ void CWallet::SetupDescriptorScriptPubKeyMans() spk_manager->SetupDescriptorGeneration(master_key, t); uint256 id = spk_manager->GetID(); m_spk_managers[id] = std::move(spk_manager); - SetActiveScriptPubKeyMan(id, t, internal); + AddActiveScriptPubKeyMan(id, t, internal); } } } -void CWallet::SetActiveScriptPubKeyMan(uint256 id, OutputType type, bool internal, bool memonly) +void CWallet::AddActiveScriptPubKeyMan(uint256 id, OutputType type, bool internal) +{ + WalletBatch batch(*database); + if (!batch.WriteActiveScriptPubKeyMan(static_cast<uint8_t>(type), id, internal)) { + throw std::runtime_error(std::string(__func__) + ": writing active ScriptPubKeyMan id failed"); + } + LoadActiveScriptPubKeyMan(id, type, internal); +} + +void CWallet::LoadActiveScriptPubKeyMan(uint256 id, OutputType type, bool internal) { WalletLogPrintf("Setting spkMan to active: id = %s, type = %d, internal = %d\n", id.ToString(), static_cast<int>(type), static_cast<int>(internal)); auto& spk_mans = internal ? m_internal_spk_managers : m_external_spk_managers; @@ -4378,12 +4455,6 @@ void CWallet::SetActiveScriptPubKeyMan(uint256 id, OutputType type, bool interna spk_man->SetInternal(internal); spk_mans[type] = spk_man; - if (!memonly) { - WalletBatch batch(*database); - if (!batch.WriteActiveScriptPubKeyMan(static_cast<uint8_t>(type), id, internal)) { - throw std::runtime_error(std::string(__func__) + ": writing active ScriptPubKeyMan id failed"); - } - } NotifyCanGetAddressesChanged(); } diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h index e3141baef0..a761caf38c 100644 --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -13,11 +13,11 @@ #include <policy/feerate.h> #include <psbt.h> #include <tinyformat.h> -#include <ui_interface.h> #include <util/message.h> #include <util/strencodings.h> #include <util/string.h> #include <util/system.h> +#include <util/ui_change_type.h> #include <validationinterface.h> #include <wallet/coinselection.h> #include <wallet/crypter.h> @@ -51,7 +51,6 @@ void UnloadWallet(std::shared_ptr<CWallet>&& wallet); bool AddWallet(const std::shared_ptr<CWallet>& wallet); bool RemoveWallet(const std::shared_ptr<CWallet>& wallet); -bool HasWallets(); std::vector<std::shared_ptr<CWallet>> GetWallets(); std::shared_ptr<CWallet> GetWallet(const std::string& name); std::shared_ptr<CWallet> LoadWallet(interfaces::Chain& chain, const WalletLocation& location, bilingual_str& error, std::vector<bilingual_str>& warnings); @@ -106,9 +105,6 @@ class ReserveDestination; //! Default for -addresstype constexpr OutputType DEFAULT_ADDRESS_TYPE{OutputType::BECH32}; -//! Default for -changetype -constexpr OutputType DEFAULT_CHANGE_TYPE{OutputType::CHANGE_AUTO}; - static constexpr uint64_t KNOWN_WALLET_FLAGS = WALLET_FLAG_AVOID_REUSE | WALLET_FLAG_BLANK_WALLET @@ -921,7 +917,7 @@ public: uint256 last_failed_block; }; ScanResult ScanForWalletTransactions(const uint256& start_block, int start_height, Optional<int> max_height, const WalletRescanReserver& reserver, bool fUpdate); - void transactionRemovedFromMempool(const CTransactionRef &ptx) override; + void transactionRemovedFromMempool(const CTransactionRef& tx, MemPoolRemovalReason reason) override; void ReacceptWalletTransactions() EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); void ResendWalletTransactions(); struct Balance { @@ -935,7 +931,7 @@ public: Balance GetBalance(int min_depth = 0, bool avoid_reuse = true) const; CAmount GetAvailableBalance(const CCoinControl* coinControl = nullptr) const; - OutputType TransactionChangeType(OutputType change_type, const std::vector<CRecipient>& vecSend); + OutputType TransactionChangeType(const Optional<OutputType>& change_type, const std::vector<CRecipient>& vecSend); /** * Insert additional inputs into the transaction by @@ -965,7 +961,8 @@ public: bool& complete, int sighash_type = 1 /* SIGHASH_ALL */, bool sign = true, - bool bip32derivs = true) const; + bool bip32derivs = true, + size_t* n_signed = nullptr) const; /** * Create a new transaction paying the recipients with a set of coins @@ -1012,7 +1009,13 @@ public: CFeeRate m_fallback_fee{DEFAULT_FALLBACK_FEE}; CFeeRate m_discard_rate{DEFAULT_DISCARD_FEE}; OutputType m_default_address_type{DEFAULT_ADDRESS_TYPE}; - OutputType m_default_change_type{DEFAULT_CHANGE_TYPE}; + /** + * Default output type for change outputs. When unset, automatically choose type + * based on address type setting and the types other of non-change outputs + * (see -changetype option documentation and implementation in + * CWallet::TransactionChangeType for details). + */ + Optional<OutputType> m_default_change_type{}; /** Absolute maximum transaction fee (in satoshis) used by default for the wallet */ CAmount m_default_max_tx_fee{DEFAULT_TRANSACTION_MAXFEE}; @@ -1084,7 +1087,10 @@ public: bool HasWalletSpend(const uint256& txid) const EXCLUSIVE_LOCKS_REQUIRED(cs_wallet); //! Flush wallet (bitdb flush) - void Flush(bool shutdown=false); + void Flush(); + + //! Close wallet database + void Close(); /** Wallet is about to be unloaded */ boost::signals2::signal<void ()> NotifyUnload; @@ -1173,7 +1179,9 @@ public: /** overwrite all flags by the given uint64_t returns false if unknown, non-tolerable flags are present */ - bool SetWalletFlags(uint64_t overwriteFlags, bool memOnly); + bool AddWalletFlags(uint64_t flags); + /** Loads the flags into the wallet. (used by LoadWallet) */ + bool LoadWalletFlags(uint64_t flags); /** Determine if we are a legacy wallet */ bool IsLegacy() const; @@ -1251,12 +1259,17 @@ public: //! Instantiate a descriptor ScriptPubKeyMan from the WalletDescriptor and load it void LoadDescriptorScriptPubKeyMan(uint256 id, WalletDescriptor& desc); - //! Sets the active ScriptPubKeyMan for the specified type and internal + //! Adds the active ScriptPubKeyMan for the specified type and internal. Writes it to the wallet file + //! @param[in] id The unique id for the ScriptPubKeyMan + //! @param[in] type The OutputType this ScriptPubKeyMan provides addresses for + //! @param[in] internal Whether this ScriptPubKeyMan provides change addresses + void AddActiveScriptPubKeyMan(uint256 id, OutputType type, bool internal); + + //! Loads an active ScriptPubKeyMan for the specified type and internal. (used by LoadWallet) //! @param[in] id The unique id for the ScriptPubKeyMan //! @param[in] type The OutputType this ScriptPubKeyMan provides addresses for //! @param[in] internal Whether this ScriptPubKeyMan provides change addresses - //! @param[in] memonly Whether to record this update to the database. Set to true for wallet loading, normally false when actually updating the wallet. - void SetActiveScriptPubKeyMan(uint256 id, OutputType type, bool internal, bool memonly = false); + void LoadActiveScriptPubKeyMan(uint256 id, OutputType type, bool internal); //! Create new DescriptorScriptPubKeyMans and add them to the wallet void SetupDescriptorScriptPubKeyMans(); diff --git a/src/wallet/walletdb.cpp b/src/wallet/walletdb.cpp index e7adbfea77..fa6814d0d3 100644 --- a/src/wallet/walletdb.cpp +++ b/src/wallet/walletdb.cpp @@ -18,8 +18,6 @@ #include <atomic> #include <string> -#include <boost/thread.hpp> - namespace DBKeys { const std::string ACENTRY{"acentry"}; const std::string ACTIVEEXTERNALSPK{"activeexternalspk"}; @@ -105,7 +103,7 @@ bool WalletBatch::WriteKey(const CPubKey& vchPubKey, const CPrivKey& vchPrivKey, vchKey.insert(vchKey.end(), vchPubKey.begin(), vchPubKey.end()); vchKey.insert(vchKey.end(), vchPrivKey.begin(), vchPrivKey.end()); - return WriteIC(std::make_pair(DBKeys::KEY, vchPubKey), std::make_pair(vchPrivKey, Hash(vchKey.begin(), vchKey.end())), false); + return WriteIC(std::make_pair(DBKeys::KEY, vchPubKey), std::make_pair(vchPrivKey, Hash(vchKey)), false); } bool WalletBatch::WriteCryptedKey(const CPubKey& vchPubKey, @@ -117,13 +115,13 @@ bool WalletBatch::WriteCryptedKey(const CPubKey& vchPubKey, } // Compute a checksum of the encrypted key - uint256 checksum = Hash(vchCryptedSecret.begin(), vchCryptedSecret.end()); + uint256 checksum = Hash(vchCryptedSecret); const auto key = std::make_pair(DBKeys::CRYPTED_KEY, vchPubKey); if (!WriteIC(key, std::make_pair(vchCryptedSecret, checksum), false)) { // It may already exist, so try writing just the checksum std::vector<unsigned char> val; - if (!m_batch.Read(key, val)) { + if (!m_batch->Read(key, val)) { return false; } if (!WriteIC(key, std::make_pair(val, checksum), true)) { @@ -168,8 +166,8 @@ bool WalletBatch::WriteBestBlock(const CBlockLocator& locator) bool WalletBatch::ReadBestBlock(CBlockLocator& locator) { - if (m_batch.Read(DBKeys::BESTBLOCK, locator) && !locator.vHave.empty()) return true; - return m_batch.Read(DBKeys::BESTBLOCK_NOMERKLE, locator); + if (m_batch->Read(DBKeys::BESTBLOCK, locator) && !locator.vHave.empty()) return true; + return m_batch->Read(DBKeys::BESTBLOCK_NOMERKLE, locator); } bool WalletBatch::WriteOrderPosNext(int64_t nOrderPosNext) @@ -179,7 +177,7 @@ bool WalletBatch::WriteOrderPosNext(int64_t nOrderPosNext) bool WalletBatch::ReadPool(int64_t nPool, CKeyPool& keypool) { - return m_batch.Read(std::make_pair(DBKeys::POOL, nPool), keypool); + return m_batch->Read(std::make_pair(DBKeys::POOL, nPool), keypool); } bool WalletBatch::WritePool(int64_t nPool, const CKeyPool& keypool) @@ -211,7 +209,7 @@ bool WalletBatch::WriteDescriptorKey(const uint256& desc_id, const CPubKey& pubk key.insert(key.end(), pubkey.begin(), pubkey.end()); key.insert(key.end(), privkey.begin(), privkey.end()); - return WriteIC(std::make_pair(DBKeys::WALLETDESCRIPTORKEY, std::make_pair(desc_id, pubkey)), std::make_pair(privkey, Hash(key.begin(), key.end())), false); + return WriteIC(std::make_pair(DBKeys::WALLETDESCRIPTORKEY, std::make_pair(desc_id, pubkey)), std::make_pair(privkey, Hash(key)), false); } bool WalletBatch::WriteCryptedDescriptorKey(const uint256& desc_id, const CPubKey& pubkey, const std::vector<unsigned char>& secret) @@ -367,7 +365,7 @@ ReadKeyValue(CWallet* pwallet, CDataStream& ssKey, CDataStream& ssValue, vchKey.insert(vchKey.end(), vchPubKey.begin(), vchPubKey.end()); vchKey.insert(vchKey.end(), pkey.begin(), pkey.end()); - if (Hash(vchKey.begin(), vchKey.end()) != hash) + if (Hash(vchKey) != hash) { strErr = "Error reading wallet database: CPubKey/CPrivKey corrupt"; return false; @@ -416,7 +414,7 @@ ReadKeyValue(CWallet* pwallet, CDataStream& ssKey, CDataStream& ssValue, if (!ssValue.eof()) { uint256 checksum; ssValue >> checksum; - if ((checksum_valid = Hash(vchPrivKey.begin(), vchPrivKey.end()) != checksum)) { + if ((checksum_valid = Hash(vchPrivKey) != checksum)) { strErr = "Error reading wallet database: Crypted key corrupt"; return false; } @@ -441,10 +439,11 @@ ReadKeyValue(CWallet* pwallet, CDataStream& ssKey, CDataStream& ssValue, // Extract some CHDChain info from this metadata if it has any if (keyMeta.nVersion >= CKeyMetadata::VERSION_WITH_HDDATA && !keyMeta.hd_seed_id.IsNull() && keyMeta.hdKeypath.size() > 0) { // Get the path from the key origin or from the path string - // Not applicable when path is "s" as that indicates a seed + // Not applicable when path is "s" or "m" as those indicate a seed + // See https://github.com/bitcoin/bitcoin/pull/12924 bool internal = false; uint32_t index = 0; - if (keyMeta.hdKeypath != "s") { + if (keyMeta.hdKeypath != "s" && keyMeta.hdKeypath != "m") { std::vector<uint32_t> path; if (keyMeta.has_key_origin) { // We have a key origin, so pull it from its path vector @@ -540,11 +539,11 @@ ReadKeyValue(CWallet* pwallet, CDataStream& ssKey, CDataStream& ssValue, } else if (strType == DBKeys::HDCHAIN) { CHDChain chain; ssValue >> chain; - pwallet->GetOrCreateLegacyScriptPubKeyMan()->SetHDChain(chain, true); + pwallet->GetOrCreateLegacyScriptPubKeyMan()->LoadHDChain(chain); } else if (strType == DBKeys::FLAGS) { uint64_t flags; ssValue >> flags; - if (!pwallet->SetWalletFlags(flags, true)) { + if (!pwallet->LoadWalletFlags(flags)) { strErr = "Error reading wallet database: Unknown non-tolerable wallet flags found"; return false; } @@ -593,9 +592,6 @@ ReadKeyValue(CWallet* pwallet, CDataStream& ssKey, CDataStream& ssValue, ssValue >> ser_xpub; CExtPubKey xpub; xpub.Decode(ser_xpub.data()); - if (wss.m_descriptor_caches.count(desc_id)) { - wss.m_descriptor_caches[desc_id] = DescriptorCache(); - } if (parent) { wss.m_descriptor_caches[desc_id].CacheParentExtPubKey(key_exp_index, xpub); } else { @@ -625,7 +621,7 @@ ReadKeyValue(CWallet* pwallet, CDataStream& ssKey, CDataStream& ssValue, to_hash.insert(to_hash.end(), pubkey.begin(), pubkey.end()); to_hash.insert(to_hash.end(), pkey.begin(), pkey.end()); - if (Hash(to_hash.begin(), to_hash.end()) != hash) + if (Hash(to_hash) != hash) { strErr = "Error reading wallet database: CPubKey/CPrivKey corrupt"; return false; @@ -694,15 +690,14 @@ DBErrors WalletBatch::LoadWallet(CWallet* pwallet) LOCK(pwallet->cs_wallet); try { int nMinVersion = 0; - if (m_batch.Read(DBKeys::MINVERSION, nMinVersion)) { + if (m_batch->Read(DBKeys::MINVERSION, nMinVersion)) { if (nMinVersion > FEATURE_LATEST) return DBErrors::TOO_NEW; pwallet->LoadMinVersion(nMinVersion); } // Get cursor - Dbc* pcursor = m_batch.GetCursor(); - if (!pcursor) + if (!m_batch->StartCursor()) { pwallet->WalletLogPrintf("Error getting wallet database cursor\n"); return DBErrors::CORRUPT; @@ -713,11 +708,14 @@ DBErrors WalletBatch::LoadWallet(CWallet* pwallet) // Read next record CDataStream ssKey(SER_DISK, CLIENT_VERSION); CDataStream ssValue(SER_DISK, CLIENT_VERSION); - int ret = m_batch.ReadAtCursor(pcursor, ssKey, ssValue); - if (ret == DB_NOTFOUND) + bool complete; + bool ret = m_batch->ReadAtCursor(ssKey, ssValue, complete); + if (complete) { break; - else if (ret != 0) + } + else if (!ret) { + m_batch->CloseCursor(); pwallet->WalletLogPrintf("Error reading next record from wallet database\n"); return DBErrors::CORRUPT; } @@ -744,21 +742,17 @@ DBErrors WalletBatch::LoadWallet(CWallet* pwallet) if (!strErr.empty()) pwallet->WalletLogPrintf("%s\n", strErr); } - pcursor->close(); - } - catch (const boost::thread_interrupted&) { - throw; - } - catch (...) { + } catch (...) { result = DBErrors::CORRUPT; } + m_batch->CloseCursor(); // Set the active ScriptPubKeyMans for (auto spk_man_pair : wss.m_active_external_spks) { - pwallet->SetActiveScriptPubKeyMan(spk_man_pair.second, spk_man_pair.first, /* internal */ false, /* memonly */ true); + pwallet->LoadActiveScriptPubKeyMan(spk_man_pair.second, spk_man_pair.first, /* internal */ false); } for (auto spk_man_pair : wss.m_active_internal_spks) { - pwallet->SetActiveScriptPubKeyMan(spk_man_pair.second, spk_man_pair.first, /* internal */ true, /* memonly */ true); + pwallet->LoadActiveScriptPubKeyMan(spk_man_pair.second, spk_man_pair.first, /* internal */ true); } // Set the descriptor caches @@ -788,7 +782,7 @@ DBErrors WalletBatch::LoadWallet(CWallet* pwallet) // Last client version to open this wallet, was previously the file version number int last_client = CLIENT_VERSION; - m_batch.Read(DBKeys::VERSION, last_client); + m_batch->Read(DBKeys::VERSION, last_client); int wallet_version = pwallet->GetVersion(); pwallet->WalletLogPrintf("Wallet File Version = %d\n", wallet_version > 0 ? wallet_version : last_client); @@ -813,7 +807,7 @@ DBErrors WalletBatch::LoadWallet(CWallet* pwallet) return DBErrors::NEED_REWRITE; if (last_client < CLIENT_VERSION) // Update - m_batch.Write(DBKeys::VERSION, CLIENT_VERSION); + m_batch->Write(DBKeys::VERSION, CLIENT_VERSION); if (wss.fAnyUnordered) result = pwallet->ReorderTransactions(); @@ -849,14 +843,13 @@ DBErrors WalletBatch::FindWalletTx(std::vector<uint256>& vTxHash, std::list<CWal try { int nMinVersion = 0; - if (m_batch.Read(DBKeys::MINVERSION, nMinVersion)) { + if (m_batch->Read(DBKeys::MINVERSION, nMinVersion)) { if (nMinVersion > FEATURE_LATEST) return DBErrors::TOO_NEW; } // Get cursor - Dbc* pcursor = m_batch.GetCursor(); - if (!pcursor) + if (!m_batch->StartCursor()) { LogPrintf("Error getting wallet database cursor\n"); return DBErrors::CORRUPT; @@ -867,11 +860,12 @@ DBErrors WalletBatch::FindWalletTx(std::vector<uint256>& vTxHash, std::list<CWal // Read next record CDataStream ssKey(SER_DISK, CLIENT_VERSION); CDataStream ssValue(SER_DISK, CLIENT_VERSION); - int ret = m_batch.ReadAtCursor(pcursor, ssKey, ssValue); - if (ret == DB_NOTFOUND) + bool complete; + bool ret = m_batch->ReadAtCursor(ssKey, ssValue, complete); + if (complete) { break; - else if (ret != 0) - { + } else if (!ret) { + m_batch->CloseCursor(); LogPrintf("Error reading next record from wallet database\n"); return DBErrors::CORRUPT; } @@ -886,14 +880,10 @@ DBErrors WalletBatch::FindWalletTx(std::vector<uint256>& vTxHash, std::list<CWal ssValue >> vWtx.back(); } } - pcursor->close(); - } - catch (const boost::thread_interrupted&) { - throw; - } - catch (...) { + } catch (...) { result = DBErrors::CORRUPT; } + m_batch->CloseCursor(); return result; } @@ -959,9 +949,6 @@ void MaybeCompactWalletDB() if (fOneThread.exchange(true)) { return; } - if (!gArgs.GetBoolArg("-flushwallet", DEFAULT_FLUSHWALLET)) { - return; - } for (const std::shared_ptr<CWallet>& pwallet : GetWallets()) { WalletDatabase& dbh = pwallet->GetDBHandle(); @@ -974,7 +961,7 @@ void MaybeCompactWalletDB() } if (dbh.nLastFlushed != nUpdateCounter && GetTime() - dbh.nLastWalletUpdate >= 2) { - if (BerkeleyBatch::PeriodicFlush(dbh)) { + if (dbh.PeriodicFlush()) { dbh.nLastFlushed = nUpdateCounter; } } @@ -983,16 +970,6 @@ void MaybeCompactWalletDB() fOneThread = false; } -bool WalletBatch::VerifyEnvironment(const fs::path& wallet_path, bilingual_str& errorStr) -{ - return BerkeleyBatch::VerifyEnvironment(wallet_path, errorStr); -} - -bool WalletBatch::VerifyDatabaseFile(const fs::path& wallet_path, bilingual_str& errorStr) -{ - return BerkeleyBatch::VerifyDatabaseFile(wallet_path, errorStr); -} - bool WalletBatch::WriteDestData(const std::string &address, const std::string &key, const std::string &value) { return WriteIC(std::make_pair(DBKeys::DESTDATA, std::make_pair(address, key)), value); @@ -1016,15 +993,39 @@ bool WalletBatch::WriteWalletFlags(const uint64_t flags) bool WalletBatch::TxnBegin() { - return m_batch.TxnBegin(); + return m_batch->TxnBegin(); } bool WalletBatch::TxnCommit() { - return m_batch.TxnCommit(); + return m_batch->TxnCommit(); } bool WalletBatch::TxnAbort() { - return m_batch.TxnAbort(); + return m_batch->TxnAbort(); +} + +bool IsWalletLoaded(const fs::path& wallet_path) +{ + return IsBDBWalletLoaded(wallet_path); +} + +/** Return object for accessing database at specified path. */ +std::unique_ptr<WalletDatabase> CreateWalletDatabase(const fs::path& path) +{ + std::string filename; + return MakeUnique<BerkeleyDatabase>(GetWalletEnv(path, filename), std::move(filename)); +} + +/** Return object for accessing dummy database with no read/write capabilities. */ +std::unique_ptr<WalletDatabase> CreateDummyWalletDatabase() +{ + return MakeUnique<DummyDatabase>(); +} + +/** Return object for accessing temporary in-memory database. */ +std::unique_ptr<WalletDatabase> CreateMockWalletDatabase() +{ + return MakeUnique<BerkeleyDatabase>(std::make_shared<BerkeleyEnvironment>(), ""); } diff --git a/src/wallet/walletdb.h b/src/wallet/walletdb.h index b95ed24d12..7c5bf7652b 100644 --- a/src/wallet/walletdb.h +++ b/src/wallet/walletdb.h @@ -8,6 +8,7 @@ #include <amount.h> #include <script/sign.h> +#include <wallet/bdb.h> #include <wallet/db.h> #include <wallet/walletutil.h> #include <key.h> @@ -39,9 +40,6 @@ class CWalletTx; class uint160; class uint256; -/** Backend-agnostic database type. */ -using WalletDatabase = BerkeleyDatabase; - /** Error statuses for the wallet database */ enum class DBErrors { @@ -182,12 +180,12 @@ private: template <typename K, typename T> bool WriteIC(const K& key, const T& value, bool fOverwrite = true) { - if (!m_batch.Write(key, value, fOverwrite)) { + if (!m_batch->Write(key, value, fOverwrite)) { return false; } m_database.IncrementUpdateCounter(); if (m_database.nUpdateCounter % 1000 == 0) { - m_batch.Flush(); + m_batch->Flush(); } return true; } @@ -195,19 +193,19 @@ private: template <typename K> bool EraseIC(const K& key) { - if (!m_batch.Erase(key)) { + if (!m_batch->Erase(key)) { return false; } m_database.IncrementUpdateCounter(); if (m_database.nUpdateCounter % 1000 == 0) { - m_batch.Flush(); + m_batch->Flush(); } return true; } public: explicit WalletBatch(WalletDatabase& database, const char* pszMode = "r+", bool _fFlushOnClose = true) : - m_batch(database, pszMode, _fFlushOnClose), + m_batch(database.MakeBatch(pszMode, _fFlushOnClose)), m_database(database) { } @@ -279,7 +277,7 @@ public: //! Abort current transaction bool TxnAbort(); private: - BerkeleyBatch m_batch; + std::unique_ptr<DatabaseBatch> m_batch; WalletDatabase& m_database; }; @@ -289,4 +287,16 @@ void MaybeCompactWalletDB(); //! Unserialize a given Key-Value pair and load it into the wallet bool ReadKeyValue(CWallet* pwallet, CDataStream& ssKey, CDataStream& ssValue, std::string& strType, std::string& strErr); +/** Return whether a wallet database is currently loaded. */ +bool IsWalletLoaded(const fs::path& wallet_path); + +/** Return object for accessing database at specified path. */ +std::unique_ptr<WalletDatabase> CreateWalletDatabase(const fs::path& path); + +/** Return object for accessing dummy database with no read/write capabilities. */ +std::unique_ptr<WalletDatabase> CreateDummyWalletDatabase(); + +/** Return object for accessing temporary in-memory database. */ +std::unique_ptr<WalletDatabase> CreateMockWalletDatabase(); + #endif // BITCOIN_WALLET_WALLETDB_H diff --git a/src/wallet/wallettool.cpp b/src/wallet/wallettool.cpp index be07c28503..9f25b1ae7d 100644 --- a/src/wallet/wallettool.cpp +++ b/src/wallet/wallettool.cpp @@ -17,7 +17,7 @@ namespace WalletTool { static void WalletToolReleaseWallet(CWallet* wallet) { wallet->WalletLogPrintf("Releasing wallet\n"); - wallet->Flush(true); + wallet->Close(); delete wallet; } @@ -28,7 +28,7 @@ static std::shared_ptr<CWallet> CreateWallet(const std::string& name, const fs:: return nullptr; } // dummy chain interface - std::shared_ptr<CWallet> wallet_instance(new CWallet(nullptr /* chain */, WalletLocation(name), WalletDatabase::Create(path)), WalletToolReleaseWallet); + std::shared_ptr<CWallet> wallet_instance(new CWallet(nullptr /* chain */, WalletLocation(name), CreateWalletDatabase(path)), WalletToolReleaseWallet); LOCK(wallet_instance->cs_wallet); bool first_run = true; DBErrors load_wallet_ret = wallet_instance->LoadWallet(first_run); @@ -57,7 +57,7 @@ static std::shared_ptr<CWallet> LoadWallet(const std::string& name, const fs::pa } // dummy chain interface - std::shared_ptr<CWallet> wallet_instance(new CWallet(nullptr /* chain */, WalletLocation(name), WalletDatabase::Create(path)), WalletToolReleaseWallet); + std::shared_ptr<CWallet> wallet_instance(new CWallet(nullptr /* chain */, WalletLocation(name), CreateWalletDatabase(path)), WalletToolReleaseWallet); DBErrors load_wallet_ret; try { bool first_run; @@ -107,12 +107,12 @@ static void WalletShowInfo(CWallet* wallet_instance) static bool SalvageWallet(const fs::path& path) { // Create a Database handle to allow for the db to be initialized before recovery - std::unique_ptr<WalletDatabase> database = WalletDatabase::Create(path); + std::unique_ptr<WalletDatabase> database = CreateWalletDatabase(path); // Initialize the environment before recovery bilingual_str error_string; try { - WalletBatch::VerifyEnvironment(path, error_string); + database->Verify(error_string); } catch (const fs::filesystem_error& e) { error_string = Untranslated(strprintf("Error loading wallet. %s", fsbridge::get_filesystem_error_message(e))); } @@ -133,24 +133,19 @@ bool ExecuteWalletToolFunc(const std::string& command, const std::string& name) std::shared_ptr<CWallet> wallet_instance = CreateWallet(name, path); if (wallet_instance) { WalletShowInfo(wallet_instance.get()); - wallet_instance->Flush(true); + wallet_instance->Close(); } } else if (command == "info" || command == "salvage") { if (!fs::exists(path)) { tfm::format(std::cerr, "Error: no wallet file at %s\n", name); return false; } - bilingual_str error; - if (!WalletBatch::VerifyEnvironment(path, error)) { - tfm::format(std::cerr, "%s\nError loading %s. Is wallet being used by other process?\n", error.original, name); - return false; - } if (command == "info") { std::shared_ptr<CWallet> wallet_instance = LoadWallet(name, path); if (!wallet_instance) return false; WalletShowInfo(wallet_instance.get()); - wallet_instance->Flush(true); + wallet_instance->Close(); } else if (command == "salvage") { return SalvageWallet(path); } diff --git a/src/walletinitinterface.h b/src/walletinitinterface.h index f4730273f1..a55e02f2dc 100644 --- a/src/walletinitinterface.h +++ b/src/walletinitinterface.h @@ -5,6 +5,8 @@ #ifndef BITCOIN_WALLETINITINTERFACE_H #define BITCOIN_WALLETINITINTERFACE_H +class ArgsManager; + struct NodeContext; class WalletInitInterface { @@ -12,7 +14,7 @@ public: /** Is the wallet component enabled */ virtual bool HasWalletSupport() const = 0; /** Get wallet help string */ - virtual void AddWalletOptions() const = 0; + virtual void AddWalletOptions(ArgsManager& argsman) const = 0; /** Check wallet parameter interaction */ virtual bool ParameterInteraction() const = 0; /** Add wallets that should be opened to list of chain clients. */ diff --git a/src/warnings.cpp b/src/warnings.cpp index 467c3d0f65..501bf7e637 100644 --- a/src/warnings.cpp +++ b/src/warnings.cpp @@ -6,66 +6,71 @@ #include <warnings.h> #include <sync.h> +#include <util/string.h> #include <util/system.h> #include <util/translation.h> -static RecursiveMutex cs_warnings; -static std::string strMiscWarning GUARDED_BY(cs_warnings); -static bool fLargeWorkForkFound GUARDED_BY(cs_warnings) = false; -static bool fLargeWorkInvalidChainFound GUARDED_BY(cs_warnings) = false; +#include <vector> -void SetMiscWarning(const std::string& strWarning) +static Mutex g_warnings_mutex; +static bilingual_str g_misc_warnings GUARDED_BY(g_warnings_mutex); +static bool fLargeWorkForkFound GUARDED_BY(g_warnings_mutex) = false; +static bool fLargeWorkInvalidChainFound GUARDED_BY(g_warnings_mutex) = false; + +void SetMiscWarning(const bilingual_str& warning) { - LOCK(cs_warnings); - strMiscWarning = strWarning; + LOCK(g_warnings_mutex); + g_misc_warnings = warning; } void SetfLargeWorkForkFound(bool flag) { - LOCK(cs_warnings); + LOCK(g_warnings_mutex); fLargeWorkForkFound = flag; } bool GetfLargeWorkForkFound() { - LOCK(cs_warnings); + LOCK(g_warnings_mutex); return fLargeWorkForkFound; } void SetfLargeWorkInvalidChainFound(bool flag) { - LOCK(cs_warnings); + LOCK(g_warnings_mutex); fLargeWorkInvalidChainFound = flag; } -std::string GetWarnings(bool verbose) +bilingual_str GetWarnings(bool verbose) { - std::string warnings_concise; - std::string warnings_verbose; - const std::string warning_separator = "<hr />"; + bilingual_str warnings_concise; + std::vector<bilingual_str> warnings_verbose; - LOCK(cs_warnings); + LOCK(g_warnings_mutex); // Pre-release build warning if (!CLIENT_VERSION_IS_RELEASE) { - warnings_concise = "This is a pre-release test build - use at your own risk - do not use for mining or merchant applications"; - warnings_verbose = _("This is a pre-release test build - use at your own risk - do not use for mining or merchant applications").translated; + warnings_concise = _("This is a pre-release test build - use at your own risk - do not use for mining or merchant applications"); + warnings_verbose.emplace_back(warnings_concise); } // Misc warnings like out of disk space and clock is wrong - if (strMiscWarning != "") { - warnings_concise = strMiscWarning; - warnings_verbose += (warnings_verbose.empty() ? "" : warning_separator) + strMiscWarning; + if (!g_misc_warnings.empty()) { + warnings_concise = g_misc_warnings; + warnings_verbose.emplace_back(warnings_concise); } if (fLargeWorkForkFound) { - warnings_concise = "Warning: The network does not appear to fully agree! Some miners appear to be experiencing issues."; - warnings_verbose += (warnings_verbose.empty() ? "" : warning_separator) + _("Warning: The network does not appear to fully agree! Some miners appear to be experiencing issues.").translated; + warnings_concise = _("Warning: The network does not appear to fully agree! Some miners appear to be experiencing issues."); + warnings_verbose.emplace_back(warnings_concise); } else if (fLargeWorkInvalidChainFound) { - warnings_concise = "Warning: We do not appear to fully agree with our peers! You may need to upgrade, or other nodes may need to upgrade."; - warnings_verbose += (warnings_verbose.empty() ? "" : warning_separator) + _("Warning: We do not appear to fully agree with our peers! You may need to upgrade, or other nodes may need to upgrade.").translated; + warnings_concise = _("Warning: We do not appear to fully agree with our peers! You may need to upgrade, or other nodes may need to upgrade."); + warnings_verbose.emplace_back(warnings_concise); + } + + if (verbose) { + return Join(warnings_verbose, Untranslated("<hr />")); } - if (verbose) return warnings_verbose; - else return warnings_concise; + return warnings_concise; } diff --git a/src/warnings.h b/src/warnings.h index 83b1add1ee..28546eb753 100644 --- a/src/warnings.h +++ b/src/warnings.h @@ -8,16 +8,18 @@ #include <string> -void SetMiscWarning(const std::string& strWarning); +struct bilingual_str; + +void SetMiscWarning(const bilingual_str& warning); void SetfLargeWorkForkFound(bool flag); bool GetfLargeWorkForkFound(); void SetfLargeWorkInvalidChainFound(bool flag); /** Format a string that describes several potential problems detected by the core. * @param[in] verbose bool - * - if true, get all warnings, translated (where possible), separated by <hr /> + * - if true, get all warnings separated by <hr /> * - if false, get the most important warning * @returns the warning string */ -std::string GetWarnings(bool verbose); +bilingual_str GetWarnings(bool verbose); #endif // BITCOIN_WARNINGS_H |