diff options
-rw-r--r-- | src/wallet/test/walletload_tests.cpp | 30 | ||||
-rw-r--r-- | src/wallet/walletdb.cpp | 6 |
2 files changed, 35 insertions, 1 deletions
diff --git a/src/wallet/test/walletload_tests.cpp b/src/wallet/test/walletload_tests.cpp index 73a4b77188..58166ae492 100644 --- a/src/wallet/test/walletload_tests.cpp +++ b/src/wallet/test/walletload_tests.cpp @@ -4,6 +4,7 @@ #include <wallet/test/util.h> #include <wallet/wallet.h> +#include <test/util/logging.h> #include <test/util/setup_common.h> #include <boost/test/unit_test.hpp> @@ -32,7 +33,7 @@ public: void ExpandPrivate(int pos, const SigningProvider& provider, FlatSigningProvider& out) const override {} }; -BOOST_FIXTURE_TEST_CASE(wallet_load_unknown_descriptor, TestingSetup) +BOOST_FIXTURE_TEST_CASE(wallet_load_descriptors, TestingSetup) { std::unique_ptr<WalletDatabase> database = CreateMockableWalletDatabase(); { @@ -49,6 +50,33 @@ BOOST_FIXTURE_TEST_CASE(wallet_load_unknown_descriptor, TestingSetup) const std::shared_ptr<CWallet> wallet(new CWallet(m_node.chain.get(), "", std::move(database))); BOOST_CHECK_EQUAL(wallet->LoadWallet(), DBErrors::UNKNOWN_DESCRIPTOR); } + + // Test 2 + // Now write a valid descriptor with an invalid ID. + // As the software produces another ID for the descriptor, the loading process must be aborted. + database = CreateMockableWalletDatabase(); + + // Verify the error + bool found = false; + DebugLogHelper logHelper("The descriptor ID calculated by the wallet differs from the one in DB", [&](const std::string* s) { + found = true; + return false; + }); + + { + // Write valid descriptor with invalid ID + WalletBatch batch(*database, false); + std::string desc = "wpkh([d34db33f/84h/0h/0h]xpub6DJ2dNUysrn5Vt36jH2KLBT2i1auw1tTSSomg8PhqNiUtx8QX2SvC9nrHu81fT41fvDUnhMjEzQgXnQjKEu3oaqMSzhSrHMxyyoEAmUHQbY/0/*)#cjjspncu"; + WalletDescriptor wallet_descriptor(std::make_shared<DummyDescriptor>(desc), 0, 0, 0, 0); + BOOST_CHECK(batch.WriteDescriptor(uint256::ONE, wallet_descriptor)); + } + + { + // Now try to load the wallet and verify the error. + const std::shared_ptr<CWallet> wallet(new CWallet(m_node.chain.get(), "", std::move(database))); + BOOST_CHECK_EQUAL(wallet->LoadWallet(), DBErrors::CORRUPT); + BOOST_CHECK(found); // The error must be logged + } } bool HasAnyRecordOfType(WalletDatabase& db, const std::string& key) diff --git a/src/wallet/walletdb.cpp b/src/wallet/walletdb.cpp index 2aee750ced..e5a73af385 100644 --- a/src/wallet/walletdb.cpp +++ b/src/wallet/walletdb.cpp @@ -803,6 +803,12 @@ static DBErrors LoadDescriptorWalletRecords(CWallet* pwallet, DatabaseBatch& bat } pwallet->LoadDescriptorScriptPubKeyMan(id, desc); + // Prior to doing anything with this spkm, verify ID compatibility + if (id != pwallet->GetDescriptorScriptPubKeyMan(desc)->GetID()) { + strErr = "The descriptor ID calculated by the wallet differs from the one in DB"; + return DBErrors::CORRUPT; + } + DescriptorCache cache; // Get key cache for this descriptor |