diff options
Diffstat (limited to 'src/wallet/test/walletload_tests.cpp')
-rw-r--r-- | src/wallet/test/walletload_tests.cpp | 97 |
1 files changed, 44 insertions, 53 deletions
diff --git a/src/wallet/test/walletload_tests.cpp b/src/wallet/test/walletload_tests.cpp index 9f5a4b14d3..6823eafdfa 100644 --- a/src/wallet/test/walletload_tests.cpp +++ b/src/wallet/test/walletload_tests.cpp @@ -34,7 +34,7 @@ public: BOOST_FIXTURE_TEST_CASE(wallet_load_unknown_descriptor, TestingSetup) { - std::unique_ptr<WalletDatabase> database = CreateMockWalletDatabase(); + std::unique_ptr<WalletDatabase> database = CreateMockableWalletDatabase(); { // Write unknown active descriptor WalletBatch batch(*database, false); @@ -70,38 +70,45 @@ bool HasAnyRecordOfType(WalletDatabase& db, const std::string& key) return false; } -BOOST_FIXTURE_TEST_CASE(wallet_load_verif_crypted_key_checksum, TestingSetup) +template<typename... Args> +SerializeData MakeSerializeData(const Args&... args) { - // The test duplicates the db so each case has its own db instance. - int NUMBER_OF_TESTS = 4; - std::vector<std::unique_ptr<WalletDatabase>> dbs; - CKey first_key; - auto get_db = [](std::vector<std::unique_ptr<WalletDatabase>>& dbs) { - std::unique_ptr<WalletDatabase> db = std::move(dbs.back()); - dbs.pop_back(); - return db; - }; - - { // Context setup. + CDataStream s(0, 0); + SerializeMany(s, args...); + return {s.begin(), s.end()}; +} + + +BOOST_FIXTURE_TEST_CASE(wallet_load_ckey, TestingSetup) +{ + SerializeData ckey_record_key; + SerializeData ckey_record_value; + std::map<SerializeData, SerializeData> records; + + { + // Context setup. // Create and encrypt legacy wallet - std::shared_ptr<CWallet> wallet(new CWallet(m_node.chain.get(), "", CreateMockWalletDatabase())); + std::shared_ptr<CWallet> wallet(new CWallet(m_node.chain.get(), "", CreateMockableWalletDatabase())); LOCK(wallet->cs_wallet); auto legacy_spkm = wallet->GetOrCreateLegacyScriptPubKeyMan(); BOOST_CHECK(legacy_spkm->SetupGeneration(true)); - // Get the first key in the wallet + // Retrieve a key CTxDestination dest = *Assert(legacy_spkm->GetNewDestination(OutputType::LEGACY)); CKeyID key_id = GetKeyForDestination(*legacy_spkm, dest); + CKey first_key; BOOST_CHECK(legacy_spkm->GetKey(key_id, first_key)); - // Encrypt the wallet and duplicate database + // Encrypt the wallet BOOST_CHECK(wallet->EncryptWallet("encrypt")); wallet->Flush(); - DatabaseOptions options; - for (int i=0; i < NUMBER_OF_TESTS; i++) { - dbs.emplace_back(DuplicateMockDatabase(wallet->GetDatabase(), options)); - } + // Store a copy of all the records + records = GetMockableDatabase(*wallet).m_records; + + // Get the record for the retrieved key + ckey_record_key = MakeSerializeData(DBKeys::CRYPTED_KEY, first_key.GetPubKey()); + ckey_record_value = records.at(ckey_record_key); } { @@ -112,7 +119,7 @@ BOOST_FIXTURE_TEST_CASE(wallet_load_verif_crypted_key_checksum, TestingSetup) // the records every time that 'CWallet::Unlock' gets called, which is not good. // Load the wallet and check that is encrypted - std::shared_ptr<CWallet> wallet(new CWallet(m_node.chain.get(), "", get_db(dbs))); + std::shared_ptr<CWallet> wallet(new CWallet(m_node.chain.get(), "", CreateMockableWalletDatabase(records))); BOOST_CHECK_EQUAL(wallet->LoadWallet(), DBErrors::LOAD_OK); BOOST_CHECK(wallet->IsCrypted()); BOOST_CHECK(HasAnyRecordOfType(wallet->GetDatabase(), DBKeys::CRYPTED_KEY)); @@ -127,18 +134,12 @@ BOOST_FIXTURE_TEST_CASE(wallet_load_verif_crypted_key_checksum, TestingSetup) { // Second test case: // Verify that loading up a 'ckey' with no checksum triggers a complete re-write of the crypted keys. - std::unique_ptr<WalletDatabase> db = get_db(dbs); - { - std::unique_ptr<DatabaseBatch> batch = db->MakeBatch(false); - std::pair<std::vector<unsigned char>, uint256> value; - BOOST_CHECK(batch->Read(std::make_pair(DBKeys::CRYPTED_KEY, first_key.GetPubKey()), value)); - const auto key = std::make_pair(DBKeys::CRYPTED_KEY, first_key.GetPubKey()); - BOOST_CHECK(batch->Write(key, value.first, /*fOverwrite=*/true)); - } + // Cut off the 32 byte checksum from a ckey record + records[ckey_record_key].resize(ckey_record_value.size() - 32); // Load the wallet and check that is encrypted - std::shared_ptr<CWallet> wallet(new CWallet(m_node.chain.get(), "", std::move(db))); + std::shared_ptr<CWallet> wallet(new CWallet(m_node.chain.get(), "", CreateMockableWalletDatabase(records))); BOOST_CHECK_EQUAL(wallet->LoadWallet(), DBErrors::LOAD_OK); BOOST_CHECK(wallet->IsCrypted()); BOOST_CHECK(HasAnyRecordOfType(wallet->GetDatabase(), DBKeys::CRYPTED_KEY)); @@ -154,35 +155,25 @@ BOOST_FIXTURE_TEST_CASE(wallet_load_verif_crypted_key_checksum, TestingSetup) { // Third test case: // Verify that loading up a 'ckey' with an invalid checksum throws an error. - std::unique_ptr<WalletDatabase> db = get_db(dbs); - { - std::unique_ptr<DatabaseBatch> batch = db->MakeBatch(false); - std::vector<unsigned char> crypted_data; - BOOST_CHECK(batch->Read(std::make_pair(DBKeys::CRYPTED_KEY, first_key.GetPubKey()), crypted_data)); - - // Write an invalid checksum - std::pair<std::vector<unsigned char>, uint256> value = std::make_pair(crypted_data, uint256::ONE); - const auto key = std::make_pair(DBKeys::CRYPTED_KEY, first_key.GetPubKey()); - BOOST_CHECK(batch->Write(key, value, /*fOverwrite=*/true)); - } - - std::shared_ptr<CWallet> wallet(new CWallet(m_node.chain.get(), "", std::move(db))); + + // Cut off the 32 byte checksum from a ckey record + records[ckey_record_key].resize(ckey_record_value.size() - 32); + // Fill in the checksum space with 0s + records[ckey_record_key].resize(ckey_record_value.size()); + + std::shared_ptr<CWallet> wallet(new CWallet(m_node.chain.get(), "", CreateMockableWalletDatabase(records))); BOOST_CHECK_EQUAL(wallet->LoadWallet(), DBErrors::CORRUPT); } { // Fourth test case: // Verify that loading up a 'ckey' with an invalid pubkey throws an error - std::unique_ptr<WalletDatabase> db = get_db(dbs); - { - CPubKey invalid_key; - BOOST_ASSERT(!invalid_key.IsValid()); - const auto key = std::make_pair(DBKeys::CRYPTED_KEY, invalid_key); - std::pair<std::vector<unsigned char>, uint256> value; - BOOST_CHECK(db->MakeBatch(false)->Write(key, value, /*fOverwrite=*/true)); - } - - std::shared_ptr<CWallet> wallet(new CWallet(m_node.chain.get(), "", std::move(db))); + CPubKey invalid_key; + BOOST_ASSERT(!invalid_key.IsValid()); + SerializeData key = MakeSerializeData(DBKeys::CRYPTED_KEY, invalid_key); + records[key] = ckey_record_value; + + std::shared_ptr<CWallet> wallet(new CWallet(m_node.chain.get(), "", CreateMockableWalletDatabase(records))); BOOST_CHECK_EQUAL(wallet->LoadWallet(), DBErrors::CORRUPT); } } |