diff options
author | Andrew Chow <achow101-github@achow101.com> | 2022-04-11 17:11:37 -0400 |
---|---|---|
committer | Andrew Chow <github@achow101.com> | 2023-06-01 13:09:08 -0400 |
commit | ba616b932cb9e9adb7eb9f1826caa62ce422a22d (patch) | |
tree | 032c9cf74574b1f29a39c79b76fc9495353188b9 /src/wallet | |
parent | 1d858b055daeea363e0450f327672658548be4c6 (diff) |
wallet: Add GetPrefixCursor to DatabaseBatch
In order to get records beginning with a prefix, we will need a cursor
specifically for that prefix. So add a GetPrefixCursor function and
DatabaseCursor classes for dealing with those prefixes.
Tested on each supported db engine.
1) Write two different key->value elements to db.
2) Create a new prefix cursor and walk-through every returned element,
verifying that it gets parsed properly.
3) Try to move the cursor outside the filtered range: expect failure
and flag complete=true.
Co-Authored-By: Ryan Ofsky <ryan@ofsky.org>
Co-Authored-By: furszy <matiasfurszyfer@protonmail.com>
Diffstat (limited to 'src/wallet')
-rw-r--r-- | src/wallet/bdb.cpp | 26 | ||||
-rw-r--r-- | src/wallet/bdb.h | 7 | ||||
-rw-r--r-- | src/wallet/db.h | 1 | ||||
-rw-r--r-- | src/wallet/salvage.cpp | 1 | ||||
-rw-r--r-- | src/wallet/sqlite.cpp | 44 | ||||
-rw-r--r-- | src/wallet/sqlite.h | 10 | ||||
-rw-r--r-- | src/wallet/test/db_tests.cpp | 128 | ||||
-rw-r--r-- | src/wallet/test/util.cpp | 13 | ||||
-rw-r--r-- | src/wallet/test/util.h | 22 | ||||
-rw-r--r-- | src/wallet/test/walletload_tests.cpp | 2 |
10 files changed, 239 insertions, 15 deletions
diff --git a/src/wallet/bdb.cpp b/src/wallet/bdb.cpp index 7a1916ddc3..68abdcd81e 100644 --- a/src/wallet/bdb.cpp +++ b/src/wallet/bdb.cpp @@ -668,7 +668,8 @@ void BerkeleyDatabase::ReloadDbEnv() env->ReloadDbEnv(); } -BerkeleyCursor::BerkeleyCursor(BerkeleyDatabase& database, const BerkeleyBatch& batch) +BerkeleyCursor::BerkeleyCursor(BerkeleyDatabase& database, const BerkeleyBatch& batch, Span<const std::byte> prefix) + : m_key_prefix(prefix.begin(), prefix.end()) { if (!database.m_db.get()) { throw std::runtime_error(STR_INTERNAL_BUG("BerkeleyDatabase does not exist")); @@ -685,9 +686,15 @@ DatabaseCursor::Status BerkeleyCursor::Next(DataStream& ssKey, DataStream& ssVal { if (m_cursor == nullptr) return Status::FAIL; // Read at cursor - SafeDbt datKey; + SafeDbt datKey(m_key_prefix.data(), m_key_prefix.size()); SafeDbt datValue; - int ret = m_cursor->get(datKey, datValue, DB_NEXT); + int ret = -1; + if (m_first && !m_key_prefix.empty()) { + ret = m_cursor->get(datKey, datValue, DB_SET_RANGE); + } else { + ret = m_cursor->get(datKey, datValue, DB_NEXT); + } + m_first = false; if (ret == DB_NOTFOUND) { return Status::DONE; } @@ -695,9 +702,14 @@ DatabaseCursor::Status BerkeleyCursor::Next(DataStream& ssKey, DataStream& ssVal return Status::FAIL; } + Span<const std::byte> raw_key = {AsBytePtr(datKey.get_data()), datKey.get_size()}; + if (!m_key_prefix.empty() && std::mismatch(raw_key.begin(), raw_key.end(), m_key_prefix.begin(), m_key_prefix.end()).second != m_key_prefix.end()) { + return Status::DONE; + } + // Convert to streams ssKey.clear(); - ssKey.write({AsBytePtr(datKey.get_data()), datKey.get_size()}); + ssKey.write(raw_key); ssValue.clear(); ssValue.write({AsBytePtr(datValue.get_data()), datValue.get_size()}); return Status::MORE; @@ -716,6 +728,12 @@ std::unique_ptr<DatabaseCursor> BerkeleyBatch::GetNewCursor() return std::make_unique<BerkeleyCursor>(m_database, *this); } +std::unique_ptr<DatabaseCursor> BerkeleyBatch::GetNewPrefixCursor(Span<const std::byte> prefix) +{ + if (!pdb) return nullptr; + return std::make_unique<BerkeleyCursor>(m_database, *this, prefix); +} + bool BerkeleyBatch::TxnBegin() { if (!pdb || activeTxn) diff --git a/src/wallet/bdb.h b/src/wallet/bdb.h index e8a57e8a5e..8cc03692d6 100644 --- a/src/wallet/bdb.h +++ b/src/wallet/bdb.h @@ -190,9 +190,13 @@ class BerkeleyCursor : public DatabaseCursor { private: Dbc* m_cursor; + std::vector<std::byte> m_key_prefix; + bool m_first{true}; public: - explicit BerkeleyCursor(BerkeleyDatabase& database, const BerkeleyBatch& batch); + // Constructor for cursor for records matching the prefix + // To match all records, an empty prefix may be provided. + explicit BerkeleyCursor(BerkeleyDatabase& database, const BerkeleyBatch& batch, Span<const std::byte> prefix = {}); ~BerkeleyCursor() override; Status Next(DataStream& key, DataStream& value) override; @@ -229,6 +233,7 @@ public: void Close() override; std::unique_ptr<DatabaseCursor> GetNewCursor() override; + std::unique_ptr<DatabaseCursor> GetNewPrefixCursor(Span<const std::byte> prefix) override; bool TxnBegin() override; bool TxnCommit() override; bool TxnAbort() override; diff --git a/src/wallet/db.h b/src/wallet/db.h index b4ccd13a9a..9d684225c3 100644 --- a/src/wallet/db.h +++ b/src/wallet/db.h @@ -113,6 +113,7 @@ public: virtual bool ErasePrefix(Span<const std::byte> prefix) = 0; virtual std::unique_ptr<DatabaseCursor> GetNewCursor() = 0; + virtual std::unique_ptr<DatabaseCursor> GetNewPrefixCursor(Span<const std::byte> prefix) = 0; virtual bool TxnBegin() = 0; virtual bool TxnCommit() = 0; virtual bool TxnAbort() = 0; diff --git a/src/wallet/salvage.cpp b/src/wallet/salvage.cpp index ab73e67285..e303310273 100644 --- a/src/wallet/salvage.cpp +++ b/src/wallet/salvage.cpp @@ -43,6 +43,7 @@ public: void Close() override {} std::unique_ptr<DatabaseCursor> GetNewCursor() override { return std::make_unique<DummyCursor>(); } + std::unique_ptr<DatabaseCursor> GetNewPrefixCursor(Span<const std::byte> prefix) override { return GetNewCursor(); } bool TxnBegin() override { return true; } bool TxnCommit() override { return true; } bool TxnAbort() override { return true; } diff --git a/src/wallet/sqlite.cpp b/src/wallet/sqlite.cpp index 55245707ed..fe10f911c4 100644 --- a/src/wallet/sqlite.cpp +++ b/src/wallet/sqlite.cpp @@ -9,6 +9,7 @@ #include <logging.h> #include <sync.h> #include <util/fs_helpers.h> +#include <util/check.h> #include <util/strencodings.h> #include <util/translation.h> #include <wallet/db.h> @@ -515,6 +516,7 @@ DatabaseCursor::Status SQLiteCursor::Next(DataStream& key, DataStream& value) SQLiteCursor::~SQLiteCursor() { + sqlite3_clear_bindings(m_cursor_stmt); sqlite3_reset(m_cursor_stmt); int res = sqlite3_finalize(m_cursor_stmt); if (res != SQLITE_OK) { @@ -538,6 +540,48 @@ std::unique_ptr<DatabaseCursor> SQLiteBatch::GetNewCursor() return cursor; } +std::unique_ptr<DatabaseCursor> SQLiteBatch::GetNewPrefixCursor(Span<const std::byte> prefix) +{ + if (!m_database.m_db) return nullptr; + + // To get just the records we want, the SQL statement does a comparison of the binary data + // where the data must be greater than or equal to the prefix, and less than + // the prefix incremented by one (when interpreted as an integer) + std::vector<std::byte> start_range(prefix.begin(), prefix.end()); + std::vector<std::byte> end_range(prefix.begin(), prefix.end()); + auto it = end_range.rbegin(); + for (; it != end_range.rend(); ++it) { + if (*it == std::byte(std::numeric_limits<unsigned char>::max())) { + *it = std::byte(0); + continue; + } + *it = std::byte(std::to_integer<unsigned char>(*it) + 1); + break; + } + if (it == end_range.rend()) { + // If the prefix is all 0xff bytes, clear end_range as we won't need it + end_range.clear(); + } + + auto cursor = std::make_unique<SQLiteCursor>(start_range, end_range); + if (!cursor) return nullptr; + + const char* stmt_text = end_range.empty() ? "SELECT key, value FROM main WHERE key >= ?" : + "SELECT key, value FROM main WHERE key >= ? AND key < ?"; + int res = sqlite3_prepare_v2(m_database.m_db, stmt_text, -1, &cursor->m_cursor_stmt, nullptr); + if (res != SQLITE_OK) { + throw std::runtime_error(strprintf( + "SQLiteDatabase: Failed to setup cursor SQL statement: %s\n", sqlite3_errstr(res))); + } + + if (!BindBlobToStatement(cursor->m_cursor_stmt, 1, cursor->m_prefix_range_start, "prefix_start")) return nullptr; + if (!end_range.empty()) { + if (!BindBlobToStatement(cursor->m_cursor_stmt, 2, cursor->m_prefix_range_end, "prefix_end")) return nullptr; + } + + return cursor; +} + bool SQLiteBatch::TxnBegin() { if (!m_database.m_db || sqlite3_get_autocommit(m_database.m_db) == 0) return false; diff --git a/src/wallet/sqlite.h b/src/wallet/sqlite.h index d9de40569b..0378bbb8d6 100644 --- a/src/wallet/sqlite.h +++ b/src/wallet/sqlite.h @@ -15,12 +15,21 @@ struct bilingual_str; namespace wallet { class SQLiteDatabase; +/** RAII class that provides a database cursor */ class SQLiteCursor : public DatabaseCursor { public: sqlite3_stmt* m_cursor_stmt{nullptr}; + // Copies of the prefix things for the prefix cursor. + // Prevents SQLite from accessing temp variables for the prefix things. + std::vector<std::byte> m_prefix_range_start; + std::vector<std::byte> m_prefix_range_end; explicit SQLiteCursor() {} + explicit SQLiteCursor(std::vector<std::byte> start_range, std::vector<std::byte> end_range) + : m_prefix_range_start(std::move(start_range)), + m_prefix_range_end(std::move(end_range)) + {} ~SQLiteCursor() override; Status Next(DataStream& key, DataStream& value) override; @@ -57,6 +66,7 @@ public: void Close() override; std::unique_ptr<DatabaseCursor> GetNewCursor() override; + std::unique_ptr<DatabaseCursor> GetNewPrefixCursor(Span<const std::byte> prefix) override; bool TxnBegin() override; bool TxnCommit() override; bool TxnAbort() override; diff --git a/src/wallet/test/db_tests.cpp b/src/wallet/test/db_tests.cpp index 7761308bbc..4cda35ed8d 100644 --- a/src/wallet/test/db_tests.cpp +++ b/src/wallet/test/db_tests.cpp @@ -6,13 +6,56 @@ #include <test/util/setup_common.h> #include <util/fs.h> +#include <util/translation.h> +#ifdef USE_BDB #include <wallet/bdb.h> +#endif +#ifdef USE_SQLITE +#include <wallet/sqlite.h> +#endif +#include <wallet/test/util.h> +#include <wallet/walletutil.h> // for WALLET_FLAG_DESCRIPTORS #include <fstream> #include <memory> #include <string> +inline std::ostream& operator<<(std::ostream& os, const std::pair<const SerializeData, SerializeData>& kv) +{ + Span key{kv.first}, value{kv.second}; + os << "(\"" << std::string_view{reinterpret_cast<const char*>(key.data()), key.size()} << "\", \"" + << std::string_view{reinterpret_cast<const char*>(key.data()), key.size()} << "\")"; + return os; +} + namespace wallet { + +static Span<const std::byte> StringBytes(std::string_view str) +{ + return AsBytes<const char>({str.data(), str.size()}); +} + +static SerializeData StringData(std::string_view str) +{ + auto bytes = StringBytes(str); + return SerializeData{bytes.begin(), bytes.end()}; +} + +static void CheckPrefix(DatabaseBatch& batch, Span<const std::byte> prefix, MockableData expected) +{ + std::unique_ptr<DatabaseCursor> cursor = batch.GetNewPrefixCursor(prefix); + MockableData actual; + while (true) { + DataStream key, value; + DatabaseCursor::Status status = cursor->Next(key, value); + if (status == DatabaseCursor::Status::DONE) break; + BOOST_CHECK(status == DatabaseCursor::Status::MORE); + BOOST_CHECK( + actual.emplace(SerializeData(key.begin(), key.end()), SerializeData(value.begin(), value.end())).second); + } + BOOST_CHECK_EQUAL_COLLECTIONS(actual.begin(), actual.end(), expected.begin(), expected.end()); +} + BOOST_FIXTURE_TEST_SUITE(db_tests, BasicTestingSetup) static std::shared_ptr<BerkeleyEnvironment> GetWalletEnv(const fs::path& path, fs::path& database_filename) @@ -78,5 +121,90 @@ BOOST_AUTO_TEST_CASE(getwalletenv_g_dbenvs_free_instance) BOOST_CHECK(env_2_a == env_2_b); } +static std::vector<std::unique_ptr<WalletDatabase>> TestDatabases(const fs::path& path_root) +{ + std::vector<std::unique_ptr<WalletDatabase>> dbs; + DatabaseOptions options; + DatabaseStatus status; + bilingual_str error; +#ifdef USE_BDB + dbs.emplace_back(MakeBerkeleyDatabase(path_root / "bdb", options, status, error)); +#endif +#ifdef USE_SQLITE + dbs.emplace_back(MakeSQLiteDatabase(path_root / "sqlite", options, status, error)); +#endif + dbs.emplace_back(CreateMockableWalletDatabase()); + return dbs; +} + +BOOST_AUTO_TEST_CASE(db_cursor_prefix_range_test) +{ + // Test each supported db + for (const auto& database : TestDatabases(m_path_root)) { + BOOST_ASSERT(database); + + std::vector<std::string> prefixes = {"", "FIRST", "SECOND", "P\xfe\xff", "P\xff\x01", "\xff\xff"}; + + // Write elements to it + std::unique_ptr<DatabaseBatch> handler = database->MakeBatch(); + for (unsigned int i = 0; i < 10; i++) { + for (const auto& prefix : prefixes) { + BOOST_CHECK(handler->Write(std::make_pair(prefix, i), i)); + } + } + + // Now read all the items by prefix and verify that each element gets parsed correctly + for (const auto& prefix : prefixes) { + DataStream s_prefix; + s_prefix << prefix; + std::unique_ptr<DatabaseCursor> cursor = handler->GetNewPrefixCursor(s_prefix); + DataStream key; + DataStream value; + for (int i = 0; i < 10; i++) { + DatabaseCursor::Status status = cursor->Next(key, value); + BOOST_ASSERT(status == DatabaseCursor::Status::MORE); + + std::string key_back; + unsigned int i_back; + key >> key_back >> i_back; + BOOST_CHECK_EQUAL(key_back, prefix); + + unsigned int value_back; + value >> value_back; + BOOST_CHECK_EQUAL(value_back, i_back); + } + + // Let's now read it once more, it should return DONE + BOOST_CHECK(cursor->Next(key, value) == DatabaseCursor::Status::DONE); + } + } +} + +// Lower level DatabaseBase::GetNewPrefixCursor test, to cover cases that aren't +// covered in the higher level test above. The higher level test uses +// serialized strings which are prefixed with string length, so it doesn't test +// truly empty prefixes or prefixes that begin with \xff +BOOST_AUTO_TEST_CASE(db_cursor_prefix_byte_test) +{ + const MockableData::value_type + e{StringData(""), StringData("e")}, + p{StringData("prefix"), StringData("p")}, + ps{StringData("prefixsuffix"), StringData("ps")}, + f{StringData("\xff"), StringData("f")}, + fs{StringData("\xffsuffix"), StringData("fs")}, + ff{StringData("\xff\xff"), StringData("ff")}, + ffs{StringData("\xff\xffsuffix"), StringData("ffs")}; + for (const auto& database : TestDatabases(m_path_root)) { + std::unique_ptr<DatabaseBatch> batch = database->MakeBatch(); + for (const auto& [k, v] : {e, p, ps, f, fs, ff, ffs}) { + batch->Write(MakeUCharSpan(k), MakeUCharSpan(v)); + } + CheckPrefix(*batch, StringBytes(""), {e, p, ps, f, fs, ff, ffs}); + CheckPrefix(*batch, StringBytes("prefix"), {p, ps}); + CheckPrefix(*batch, StringBytes("\xff"), {f, fs, ff, ffs}); + CheckPrefix(*batch, StringBytes("\xff\xff"), {ff, ffs}); + } +} + BOOST_AUTO_TEST_SUITE_END() } // namespace wallet diff --git a/src/wallet/test/util.cpp b/src/wallet/test/util.cpp index 4fa7adc37d..069ab25f26 100644 --- a/src/wallet/test/util.cpp +++ b/src/wallet/test/util.cpp @@ -92,6 +92,17 @@ CTxDestination getNewDestination(CWallet& w, OutputType output_type) return *Assert(w.GetNewDestination(output_type, "")); } +// BytePrefix compares equality with other byte spans that begin with the same prefix. +struct BytePrefix { Span<const std::byte> prefix; }; +bool operator<(BytePrefix a, Span<const std::byte> b) { return a.prefix < b.subspan(0, std::min(a.prefix.size(), b.size())); } +bool operator<(Span<const std::byte> a, BytePrefix b) { return a.subspan(0, std::min(a.size(), b.prefix.size())) < b.prefix; } + +MockableCursor::MockableCursor(const MockableData& records, bool pass, Span<const std::byte> prefix) +{ + m_pass = pass; + std::tie(m_cursor, m_cursor_end) = records.equal_range(BytePrefix{prefix}); +} + DatabaseCursor::Status MockableCursor::Next(DataStream& key, DataStream& value) { if (!m_pass) { @@ -175,7 +186,7 @@ bool MockableBatch::ErasePrefix(Span<const std::byte> prefix) return true; } -std::unique_ptr<WalletDatabase> CreateMockableWalletDatabase(std::map<SerializeData, SerializeData> records) +std::unique_ptr<WalletDatabase> CreateMockableWalletDatabase(MockableData records) { return std::make_unique<MockableDatabase>(records); } diff --git a/src/wallet/test/util.h b/src/wallet/test/util.h index b1ad1c959b..2a1fe639de 100644 --- a/src/wallet/test/util.h +++ b/src/wallet/test/util.h @@ -48,14 +48,17 @@ std::string getnewaddress(CWallet& w); /** Returns a new destination, of an specific type, from the wallet */ CTxDestination getNewDestination(CWallet& w, OutputType output_type); +using MockableData = std::map<SerializeData, SerializeData, std::less<>>; + class MockableCursor: public DatabaseCursor { public: - std::map<SerializeData, SerializeData>::const_iterator m_cursor; - std::map<SerializeData, SerializeData>::const_iterator m_cursor_end; + MockableData::const_iterator m_cursor; + MockableData::const_iterator m_cursor_end; bool m_pass; - explicit MockableCursor(const std::map<SerializeData, SerializeData>& records, bool pass) : m_cursor(records.begin()), m_cursor_end(records.end()), m_pass(pass) {} + explicit MockableCursor(const MockableData& records, bool pass) : m_cursor(records.begin()), m_cursor_end(records.end()), m_pass(pass) {} + MockableCursor(const MockableData& records, bool pass, Span<const std::byte> prefix); ~MockableCursor() {} Status Next(DataStream& key, DataStream& value) override; @@ -64,7 +67,7 @@ public: class MockableBatch : public DatabaseBatch { private: - std::map<SerializeData, SerializeData>& m_records; + MockableData& m_records; bool m_pass; bool ReadKey(DataStream&& key, DataStream& value) override; @@ -74,7 +77,7 @@ private: bool ErasePrefix(Span<const std::byte> prefix) override; public: - explicit MockableBatch(std::map<SerializeData, SerializeData>& records, bool pass) : m_records(records), m_pass(pass) {} + explicit MockableBatch(MockableData& records, bool pass) : m_records(records), m_pass(pass) {} ~MockableBatch() {} void Flush() override {} @@ -84,6 +87,9 @@ public: { return std::make_unique<MockableCursor>(m_records, m_pass); } + std::unique_ptr<DatabaseCursor> GetNewPrefixCursor(Span<const std::byte> prefix) override { + return std::make_unique<MockableCursor>(m_records, m_pass, prefix); + } bool TxnBegin() override { return m_pass; } bool TxnCommit() override { return m_pass; } bool TxnAbort() override { return m_pass; } @@ -94,10 +100,10 @@ public: class MockableDatabase : public WalletDatabase { public: - std::map<SerializeData, SerializeData> m_records; + MockableData m_records; bool m_pass{true}; - MockableDatabase(std::map<SerializeData, SerializeData> records = {}) : WalletDatabase(), m_records(records) {} + MockableDatabase(MockableData records = {}) : WalletDatabase(), m_records(records) {} ~MockableDatabase() {}; void Open() override {} @@ -117,7 +123,7 @@ public: std::unique_ptr<DatabaseBatch> MakeBatch(bool flush_on_close = true) override { return std::make_unique<MockableBatch>(m_records, m_pass); } }; -std::unique_ptr<WalletDatabase> CreateMockableWalletDatabase(std::map<SerializeData, SerializeData> records = {}); +std::unique_ptr<WalletDatabase> CreateMockableWalletDatabase(MockableData records = {}); MockableDatabase& GetMockableDatabase(CWallet& wallet); } // namespace wallet diff --git a/src/wallet/test/walletload_tests.cpp b/src/wallet/test/walletload_tests.cpp index 6823eafdfa..c1ff7baae1 100644 --- a/src/wallet/test/walletload_tests.cpp +++ b/src/wallet/test/walletload_tests.cpp @@ -83,7 +83,7 @@ BOOST_FIXTURE_TEST_CASE(wallet_load_ckey, TestingSetup) { SerializeData ckey_record_key; SerializeData ckey_record_value; - std::map<SerializeData, SerializeData> records; + MockableData records; { // Context setup. |