diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/rpc/client.cpp | 22 | ||||
-rw-r--r-- | src/rpc/client.h | 5 | ||||
-rw-r--r-- | src/rpc/mining.cpp | 35 | ||||
-rw-r--r-- | src/test/fuzz/parse_univalue.cpp | 9 | ||||
-rw-r--r-- | src/test/fuzz/rpc.cpp | 1 | ||||
-rw-r--r-- | src/test/fuzz/string.cpp | 4 | ||||
-rw-r--r-- | src/test/rpc_tests.cpp | 31 | ||||
-rw-r--r-- | src/txmempool.cpp | 27 | ||||
-rw-r--r-- | src/txmempool.h | 13 | ||||
-rw-r--r-- | src/univalue/test/object.cpp | 27 | ||||
-rw-r--r-- | src/wallet/bdb.cpp | 29 | ||||
-rw-r--r-- | src/wallet/bdb.h | 7 | ||||
-rw-r--r-- | src/wallet/db.h | 1 | ||||
-rw-r--r-- | src/wallet/salvage.cpp | 1 | ||||
-rw-r--r-- | src/wallet/sqlite.cpp | 76 | ||||
-rw-r--r-- | src/wallet/sqlite.h | 10 | ||||
-rw-r--r-- | src/wallet/test/db_tests.cpp | 128 | ||||
-rw-r--r-- | src/wallet/test/util.cpp | 16 | ||||
-rw-r--r-- | src/wallet/test/util.h | 22 | ||||
-rw-r--r-- | src/wallet/test/walletload_tests.cpp | 2 |
20 files changed, 391 insertions, 75 deletions
diff --git a/src/rpc/client.cpp b/src/rpc/client.cpp index 7278b57c76..edc0fb05d7 100644 --- a/src/rpc/client.cpp +++ b/src/rpc/client.cpp @@ -297,6 +297,14 @@ static const CRPCConvertParam vRPCConvertParams[] = }; // clang-format on +/** Parse string to UniValue or throw runtime_error if string contains invalid JSON */ +static UniValue Parse(std::string_view raw) +{ + UniValue parsed; + if (!parsed.read(raw)) throw std::runtime_error(tfm::format("Error parsing JSON: %s", raw)); + return parsed; +} + class CRPCConvertTable { private: @@ -309,13 +317,13 @@ public: /** Return arg_value as UniValue, and first parse it if it is a non-string parameter */ UniValue ArgToUniValue(std::string_view arg_value, const std::string& method, int param_idx) { - return members.count({method, param_idx}) > 0 ? ParseNonRFCJSONValue(arg_value) : arg_value; + return members.count({method, param_idx}) > 0 ? Parse(arg_value) : arg_value; } /** Return arg_value as UniValue, and first parse it if it is a non-string parameter */ UniValue ArgToUniValue(std::string_view arg_value, const std::string& method, const std::string& param_name) { - return membersByName.count({method, param_name}) > 0 ? ParseNonRFCJSONValue(arg_value) : arg_value; + return membersByName.count({method, param_name}) > 0 ? Parse(arg_value) : arg_value; } }; @@ -329,16 +337,6 @@ CRPCConvertTable::CRPCConvertTable() static CRPCConvertTable rpcCvtTable; -/** Non-RFC4627 JSON parser, accepts internal values (such as numbers, true, false, null) - * as well as objects and arrays. - */ -UniValue ParseNonRFCJSONValue(std::string_view raw) -{ - UniValue parsed; - if (!parsed.read(raw)) throw std::runtime_error(tfm::format("Error parsing JSON: %s", raw)); - return parsed; -} - UniValue RPCConvertValues(const std::string &strMethod, const std::vector<std::string> &strParams) { UniValue params(UniValue::VARR); diff --git a/src/rpc/client.h b/src/rpc/client.h index 3c5c4fc4d6..b67cd27fdf 100644 --- a/src/rpc/client.h +++ b/src/rpc/client.h @@ -17,9 +17,4 @@ UniValue RPCConvertValues(const std::string& strMethod, const std::vector<std::s /** Convert named arguments to command-specific RPC representation */ UniValue RPCConvertNamedValues(const std::string& strMethod, const std::vector<std::string>& strParams); -/** Non-RFC4627 JSON parser, accepts internal values (such as numbers, true, false, null) - * as well as objects and arrays. - */ -UniValue ParseNonRFCJSONValue(std::string_view raw); - #endif // BITCOIN_RPC_CLIENT_H diff --git a/src/rpc/mining.cpp b/src/rpc/mining.cpp index eb61d58a33..074cecadd2 100644 --- a/src/rpc/mining.cpp +++ b/src/rpc/mining.cpp @@ -480,6 +480,40 @@ static RPCHelpMan prioritisetransaction() }; } +static RPCHelpMan getprioritisedtransactions() +{ + return RPCHelpMan{"getprioritisedtransactions", + "Returns a map of all user-created (see prioritisetransaction) fee deltas by txid, and whether the tx is present in mempool.", + {}, + RPCResult{ + RPCResult::Type::OBJ_DYN, "prioritisation-map", "prioritisation keyed by txid", + { + {RPCResult::Type::OBJ, "txid", "", { + {RPCResult::Type::NUM, "fee_delta", "transaction fee delta in satoshis"}, + {RPCResult::Type::BOOL, "in_mempool", "whether this transaction is currently in mempool"}, + }} + }, + }, + RPCExamples{ + HelpExampleCli("getprioritisedtransactions", "") + + HelpExampleRpc("getprioritisedtransactions", "") + }, + [&](const RPCHelpMan& self, const JSONRPCRequest& request) -> UniValue + { + NodeContext& node = EnsureAnyNodeContext(request.context); + CTxMemPool& mempool = EnsureMemPool(node); + UniValue rpc_result{UniValue::VOBJ}; + for (const auto& delta_info : mempool.GetPrioritisedTransactions()) { + UniValue result_inner{UniValue::VOBJ}; + result_inner.pushKV("fee_delta", delta_info.delta); + result_inner.pushKV("in_mempool", delta_info.in_mempool); + rpc_result.pushKV(delta_info.txid.GetHex(), result_inner); + } + return rpc_result; + }, + }; +} + // NOTE: Assumes a conclusive result; if result is inconclusive, it must be handled by caller static UniValue BIP22ValidationResult(const BlockValidationState& state) @@ -1048,6 +1082,7 @@ void RegisterMiningRPCCommands(CRPCTable& t) {"mining", &getnetworkhashps}, {"mining", &getmininginfo}, {"mining", &prioritisetransaction}, + {"mining", &getprioritisedtransactions}, {"mining", &getblocktemplate}, {"mining", &submitblock}, {"mining", &submitheader}, diff --git a/src/test/fuzz/parse_univalue.cpp b/src/test/fuzz/parse_univalue.cpp index be15a38e92..6d33c1a8cc 100644 --- a/src/test/fuzz/parse_univalue.cpp +++ b/src/test/fuzz/parse_univalue.cpp @@ -22,12 +22,9 @@ FUZZ_TARGET_INIT(parse_univalue, initialize_parse_univalue) const std::string random_string(buffer.begin(), buffer.end()); bool valid = true; const UniValue univalue = [&] { - try { - return ParseNonRFCJSONValue(random_string); - } catch (const std::runtime_error&) { - valid = false; - return UniValue{}; - } + UniValue uv; + if (!uv.read(random_string)) valid = false; + return valid ? uv : UniValue{}; }(); if (!valid) { return; diff --git a/src/test/fuzz/rpc.cpp b/src/test/fuzz/rpc.cpp index 6424f756a0..b1858a1800 100644 --- a/src/test/fuzz/rpc.cpp +++ b/src/test/fuzz/rpc.cpp @@ -136,6 +136,7 @@ const std::vector<std::string> RPC_COMMANDS_SAFE_FOR_FUZZING{ "getnetworkinfo", "getnodeaddresses", "getpeerinfo", + "getprioritisedtransactions", "getrawmempool", "getrawtransaction", "getrpcinfo", diff --git a/src/test/fuzz/string.cpp b/src/test/fuzz/string.cpp index 75c78ce1bd..fd96b6e3b2 100644 --- a/src/test/fuzz/string.cpp +++ b/src/test/fuzz/string.cpp @@ -66,10 +66,6 @@ FUZZ_TARGET(string) const util::Settings settings; (void)OnlyHasDefaultSectionSetting(settings, random_string_1, random_string_2); (void)ParseNetwork(random_string_1); - try { - (void)ParseNonRFCJSONValue(random_string_1); - } catch (const std::runtime_error&) { - } (void)ParseOutputType(random_string_1); (void)RemovePrefix(random_string_1, random_string_2); (void)ResolveErrMsg(random_string_1, random_string_2); diff --git a/src/test/rpc_tests.cpp b/src/test/rpc_tests.cpp index 31c2010243..2f783a4b95 100644 --- a/src/test/rpc_tests.cpp +++ b/src/test/rpc_tests.cpp @@ -300,6 +300,7 @@ BOOST_AUTO_TEST_CASE(rpc_parse_monetary_values) BOOST_CHECK_EQUAL(AmountFromValue(ValueFromString("0.00000001000000")), 1LL); //should pass, cut trailing 0 BOOST_CHECK_THROW(AmountFromValue(ValueFromString("19e-9")), UniValue); //should fail BOOST_CHECK_EQUAL(AmountFromValue(ValueFromString("0.19e-6")), 19); //should pass, leading 0 is present + BOOST_CHECK_EXCEPTION(AmountFromValue(".19e-6"), UniValue, HasJSON(R"({"code":-3,"message":"Invalid amount"})")); //should fail, no leading 0 BOOST_CHECK_THROW(AmountFromValue(ValueFromString("92233720368.54775808")), UniValue); //overflow error BOOST_CHECK_THROW(AmountFromValue(ValueFromString("1e+11")), UniValue); //overflow error @@ -307,36 +308,6 @@ BOOST_AUTO_TEST_CASE(rpc_parse_monetary_values) BOOST_CHECK_THROW(AmountFromValue(ValueFromString("93e+9")), UniValue); //overflow error } -BOOST_AUTO_TEST_CASE(json_parse_errors) -{ - // Valid - BOOST_CHECK_EQUAL(ParseNonRFCJSONValue("1.0").get_real(), 1.0); - BOOST_CHECK_EQUAL(ParseNonRFCJSONValue("true").get_bool(), true); - BOOST_CHECK_EQUAL(ParseNonRFCJSONValue("[false]")[0].get_bool(), false); - BOOST_CHECK_EQUAL(ParseNonRFCJSONValue("{\"a\": true}")["a"].get_bool(), true); - BOOST_CHECK_EQUAL(ParseNonRFCJSONValue("{\"1\": \"true\"}")["1"].get_str(), "true"); - // Valid, with leading or trailing whitespace - BOOST_CHECK_EQUAL(ParseNonRFCJSONValue(" 1.0").get_real(), 1.0); - BOOST_CHECK_EQUAL(ParseNonRFCJSONValue("1.0 ").get_real(), 1.0); - - BOOST_CHECK_THROW(AmountFromValue(ParseNonRFCJSONValue(".19e-6")), std::runtime_error); //should fail, missing leading 0, therefore invalid JSON - BOOST_CHECK_EQUAL(AmountFromValue(ParseNonRFCJSONValue("0.00000000000000000000000000000000000001e+30 ")), 1); - // Invalid, initial garbage - BOOST_CHECK_THROW(ParseNonRFCJSONValue("[1.0"), std::runtime_error); - BOOST_CHECK_THROW(ParseNonRFCJSONValue("a1.0"), std::runtime_error); - // Invalid, trailing garbage - BOOST_CHECK_THROW(ParseNonRFCJSONValue("1.0sds"), std::runtime_error); - BOOST_CHECK_THROW(ParseNonRFCJSONValue("1.0]"), std::runtime_error); - // Invalid, keys have to be names - BOOST_CHECK_THROW(ParseNonRFCJSONValue("{1: \"true\"}"), std::runtime_error); - BOOST_CHECK_THROW(ParseNonRFCJSONValue("{true: 1}"), std::runtime_error); - BOOST_CHECK_THROW(ParseNonRFCJSONValue("{[1]: 1}"), std::runtime_error); - BOOST_CHECK_THROW(ParseNonRFCJSONValue("{{\"a\": \"a\"}: 1}"), std::runtime_error); - // BTC addresses should fail parsing - BOOST_CHECK_THROW(ParseNonRFCJSONValue("175tWpb8K1S7NmH4Zx6rewF9WQrcZv245W"), std::runtime_error); - BOOST_CHECK_THROW(ParseNonRFCJSONValue("3J98t1WpEZ73CNmQviecrnyiWrnqRhWNL"), std::runtime_error); -} - BOOST_AUTO_TEST_CASE(rpc_ban) { BOOST_CHECK_NO_THROW(CallRPC(std::string("clearbanned"))); diff --git a/src/txmempool.cpp b/src/txmempool.cpp index 1ba110d9cb..9ce4a17c5e 100644 --- a/src/txmempool.cpp +++ b/src/txmempool.cpp @@ -876,8 +876,17 @@ void CTxMemPool::PrioritiseTransaction(const uint256& hash, const CAmount& nFeeD } ++nTransactionsUpdated; } + if (delta == 0) { + mapDeltas.erase(hash); + LogPrintf("PrioritiseTransaction: %s (%sin mempool) delta cleared\n", hash.ToString(), it == mapTx.end() ? "not " : ""); + } else { + LogPrintf("PrioritiseTransaction: %s (%sin mempool) fee += %s, new delta=%s\n", + hash.ToString(), + it == mapTx.end() ? "not " : "", + FormatMoney(nFeeDelta), + FormatMoney(delta)); + } } - LogPrintf("PrioritiseTransaction: %s fee += %s\n", hash.ToString(), FormatMoney(nFeeDelta)); } void CTxMemPool::ApplyDelta(const uint256& hash, CAmount &nFeeDelta) const @@ -896,6 +905,22 @@ void CTxMemPool::ClearPrioritisation(const uint256& hash) mapDeltas.erase(hash); } +std::vector<CTxMemPool::delta_info> CTxMemPool::GetPrioritisedTransactions() const +{ + AssertLockNotHeld(cs); + LOCK(cs); + std::vector<delta_info> result; + result.reserve(mapDeltas.size()); + for (const auto& [txid, delta] : mapDeltas) { + const auto iter{mapTx.find(txid)}; + const bool in_mempool{iter != mapTx.end()}; + std::optional<CAmount> modified_fee; + if (in_mempool) modified_fee = iter->GetModifiedFee(); + result.emplace_back(delta_info{in_mempool, delta, modified_fee, txid}); + } + return result; +} + const CTransaction* CTxMemPool::GetConflictTx(const COutPoint& prevout) const { const auto it = mapNextTx.find(prevout); diff --git a/src/txmempool.h b/src/txmempool.h index 769b7f69ea..000033086b 100644 --- a/src/txmempool.h +++ b/src/txmempool.h @@ -516,6 +516,19 @@ public: void ApplyDelta(const uint256& hash, CAmount &nFeeDelta) const EXCLUSIVE_LOCKS_REQUIRED(cs); void ClearPrioritisation(const uint256& hash) EXCLUSIVE_LOCKS_REQUIRED(cs); + struct delta_info { + /** Whether this transaction is in the mempool. */ + const bool in_mempool; + /** The fee delta added using PrioritiseTransaction(). */ + const CAmount delta; + /** The modified fee (base fee + delta) of this entry. Only present if in_mempool=true. */ + std::optional<CAmount> modified_fee; + /** The prioritised transaction's txid. */ + const uint256 txid; + }; + /** Return a vector of all entries in mapDeltas with their corresponding delta_info. */ + std::vector<delta_info> GetPrioritisedTransactions() const EXCLUSIVE_LOCKS_REQUIRED(!cs); + /** Get the transaction in the pool that spends the same prevout */ const CTransaction* GetConflictTx(const COutPoint& prevout) const EXCLUSIVE_LOCKS_REQUIRED(cs); diff --git a/src/univalue/test/object.cpp b/src/univalue/test/object.cpp index 5ddf300393..5fb973c67b 100644 --- a/src/univalue/test/object.cpp +++ b/src/univalue/test/object.cpp @@ -412,6 +412,33 @@ void univalue_readwrite() BOOST_CHECK_EQUAL(strJson1, v.write()); + // Valid + BOOST_CHECK(v.read("1.0") && (v.get_real() == 1.0)); + BOOST_CHECK(v.read("true") && v.get_bool()); + BOOST_CHECK(v.read("[false]") && !v[0].get_bool()); + BOOST_CHECK(v.read("{\"a\": true}") && v["a"].get_bool()); + BOOST_CHECK(v.read("{\"1\": \"true\"}") && (v["1"].get_str() == "true")); + // Valid, with leading or trailing whitespace + BOOST_CHECK(v.read(" 1.0") && (v.get_real() == 1.0)); + BOOST_CHECK(v.read("1.0 ") && (v.get_real() == 1.0)); + BOOST_CHECK(v.read("0.00000000000000000000000000000000000001e+30 ") && v.get_real() == 1e-8); + + BOOST_CHECK(!v.read(".19e-6")); //should fail, missing leading 0, therefore invalid JSON + // Invalid, initial garbage + BOOST_CHECK(!v.read("[1.0")); + BOOST_CHECK(!v.read("a1.0")); + // Invalid, trailing garbage + BOOST_CHECK(!v.read("1.0sds")); + BOOST_CHECK(!v.read("1.0]")); + // Invalid, keys have to be names + BOOST_CHECK(!v.read("{1: \"true\"}")); + BOOST_CHECK(!v.read("{true: 1}")); + BOOST_CHECK(!v.read("{[1]: 1}")); + BOOST_CHECK(!v.read("{{\"a\": \"a\"}: 1}")); + // BTC addresses should fail parsing + BOOST_CHECK(!v.read("175tWpb8K1S7NmH4Zx6rewF9WQrcZv245W")); + BOOST_CHECK(!v.read("3J98t1WpEZ73CNmQviecrnyiWrnqRhWNL")); + /* Check for (correctly reporting) a parsing error if the initial JSON construct is followed by more stuff. Note that whitespace is, of course, exempt. */ diff --git a/src/wallet/bdb.cpp b/src/wallet/bdb.cpp index 6dce51fc12..68abdcd81e 100644 --- a/src/wallet/bdb.cpp +++ b/src/wallet/bdb.cpp @@ -668,7 +668,8 @@ void BerkeleyDatabase::ReloadDbEnv() env->ReloadDbEnv(); } -BerkeleyCursor::BerkeleyCursor(BerkeleyDatabase& database, const BerkeleyBatch& batch) +BerkeleyCursor::BerkeleyCursor(BerkeleyDatabase& database, const BerkeleyBatch& batch, Span<const std::byte> prefix) + : m_key_prefix(prefix.begin(), prefix.end()) { if (!database.m_db.get()) { throw std::runtime_error(STR_INTERNAL_BUG("BerkeleyDatabase does not exist")); @@ -685,19 +686,30 @@ DatabaseCursor::Status BerkeleyCursor::Next(DataStream& ssKey, DataStream& ssVal { if (m_cursor == nullptr) return Status::FAIL; // Read at cursor - SafeDbt datKey; + SafeDbt datKey(m_key_prefix.data(), m_key_prefix.size()); SafeDbt datValue; - int ret = m_cursor->get(datKey, datValue, DB_NEXT); + int ret = -1; + if (m_first && !m_key_prefix.empty()) { + ret = m_cursor->get(datKey, datValue, DB_SET_RANGE); + } else { + ret = m_cursor->get(datKey, datValue, DB_NEXT); + } + m_first = false; if (ret == DB_NOTFOUND) { return Status::DONE; } - if (ret != 0 || datKey.get_data() == nullptr || datValue.get_data() == nullptr) { + if (ret != 0) { return Status::FAIL; } + Span<const std::byte> raw_key = {AsBytePtr(datKey.get_data()), datKey.get_size()}; + if (!m_key_prefix.empty() && std::mismatch(raw_key.begin(), raw_key.end(), m_key_prefix.begin(), m_key_prefix.end()).second != m_key_prefix.end()) { + return Status::DONE; + } + // Convert to streams ssKey.clear(); - ssKey.write({AsBytePtr(datKey.get_data()), datKey.get_size()}); + ssKey.write(raw_key); ssValue.clear(); ssValue.write({AsBytePtr(datValue.get_data()), datValue.get_size()}); return Status::MORE; @@ -716,6 +728,12 @@ std::unique_ptr<DatabaseCursor> BerkeleyBatch::GetNewCursor() return std::make_unique<BerkeleyCursor>(m_database, *this); } +std::unique_ptr<DatabaseCursor> BerkeleyBatch::GetNewPrefixCursor(Span<const std::byte> prefix) +{ + if (!pdb) return nullptr; + return std::make_unique<BerkeleyCursor>(m_database, *this, prefix); +} + bool BerkeleyBatch::TxnBegin() { if (!pdb || activeTxn) @@ -777,6 +795,7 @@ bool BerkeleyBatch::ReadKey(DataStream&& key, DataStream& value) SafeDbt datValue; int ret = pdb->get(activeTxn, datKey, datValue, 0); if (ret == 0 && datValue.get_data() != nullptr) { + value.clear(); value.write({AsBytePtr(datValue.get_data()), datValue.get_size()}); return true; } diff --git a/src/wallet/bdb.h b/src/wallet/bdb.h index e8a57e8a5e..8cc03692d6 100644 --- a/src/wallet/bdb.h +++ b/src/wallet/bdb.h @@ -190,9 +190,13 @@ class BerkeleyCursor : public DatabaseCursor { private: Dbc* m_cursor; + std::vector<std::byte> m_key_prefix; + bool m_first{true}; public: - explicit BerkeleyCursor(BerkeleyDatabase& database, const BerkeleyBatch& batch); + // Constructor for cursor for records matching the prefix + // To match all records, an empty prefix may be provided. + explicit BerkeleyCursor(BerkeleyDatabase& database, const BerkeleyBatch& batch, Span<const std::byte> prefix = {}); ~BerkeleyCursor() override; Status Next(DataStream& key, DataStream& value) override; @@ -229,6 +233,7 @@ public: void Close() override; std::unique_ptr<DatabaseCursor> GetNewCursor() override; + std::unique_ptr<DatabaseCursor> GetNewPrefixCursor(Span<const std::byte> prefix) override; bool TxnBegin() override; bool TxnCommit() override; bool TxnAbort() override; diff --git a/src/wallet/db.h b/src/wallet/db.h index b4ccd13a9a..9d684225c3 100644 --- a/src/wallet/db.h +++ b/src/wallet/db.h @@ -113,6 +113,7 @@ public: virtual bool ErasePrefix(Span<const std::byte> prefix) = 0; virtual std::unique_ptr<DatabaseCursor> GetNewCursor() = 0; + virtual std::unique_ptr<DatabaseCursor> GetNewPrefixCursor(Span<const std::byte> prefix) = 0; virtual bool TxnBegin() = 0; virtual bool TxnCommit() = 0; virtual bool TxnAbort() = 0; diff --git a/src/wallet/salvage.cpp b/src/wallet/salvage.cpp index ab73e67285..e303310273 100644 --- a/src/wallet/salvage.cpp +++ b/src/wallet/salvage.cpp @@ -43,6 +43,7 @@ public: void Close() override {} std::unique_ptr<DatabaseCursor> GetNewCursor() override { return std::make_unique<DummyCursor>(); } + std::unique_ptr<DatabaseCursor> GetNewPrefixCursor(Span<const std::byte> prefix) override { return GetNewCursor(); } bool TxnBegin() override { return true; } bool TxnCommit() override { return true; } bool TxnAbort() override { return true; } diff --git a/src/wallet/sqlite.cpp b/src/wallet/sqlite.cpp index 77e8a4e9c1..8d7fe97bb1 100644 --- a/src/wallet/sqlite.cpp +++ b/src/wallet/sqlite.cpp @@ -9,6 +9,7 @@ #include <logging.h> #include <sync.h> #include <util/fs_helpers.h> +#include <util/check.h> #include <util/strencodings.h> #include <util/translation.h> #include <wallet/db.h> @@ -34,12 +35,31 @@ static void ErrorLogCallback(void* arg, int code, const char* msg) LogPrintf("SQLite Error. Code: %d. Message: %s\n", code, msg); } +static int TraceSqlCallback(unsigned code, void* context, void* param1, void* param2) +{ + auto* db = static_cast<SQLiteDatabase*>(context); + if (code == SQLITE_TRACE_STMT) { + auto* stmt = static_cast<sqlite3_stmt*>(param1); + // To be conservative and avoid leaking potentially secret information + // in the log file, only expand statements that query the database, not + // statements that update the database. + char* expanded{sqlite3_stmt_readonly(stmt) ? sqlite3_expanded_sql(stmt) : nullptr}; + LogPrintf("[%s] SQLite Statement: %s\n", db->Filename(), expanded ? expanded : sqlite3_sql(stmt)); + if (expanded) sqlite3_free(expanded); + } + return SQLITE_OK; +} + static bool BindBlobToStatement(sqlite3_stmt* stmt, int index, Span<const std::byte> blob, const std::string& description) { - int res = sqlite3_bind_blob(stmt, index, blob.data(), blob.size(), SQLITE_STATIC); + // Pass a pointer to the empty string "" below instead of passing the + // blob.data() pointer if the blob.data() pointer is null. Passing a null + // data pointer to bind_blob would cause sqlite to bind the SQL NULL value + // instead of the empty blob value X'', which would mess up SQL comparisons. + int res = sqlite3_bind_blob(stmt, index, blob.data() ? static_cast<const void*>(blob.data()) : "", blob.size(), SQLITE_STATIC); if (res != SQLITE_OK) { LogPrintf("Unable to bind %s to statement: %s\n", description, sqlite3_errstr(res)); sqlite3_clear_bindings(stmt); @@ -235,6 +255,13 @@ void SQLiteDatabase::Open() if (ret != SQLITE_OK) { throw std::runtime_error(strprintf("SQLiteDatabase: Failed to enable extended result codes: %s\n", sqlite3_errstr(ret))); } + // Trace SQL statements if tracing is enabled with -debug=walletdb -loglevel=walletdb:trace + if (LogAcceptCategory(BCLog::WALLETDB, BCLog::Level::Trace)) { + ret = sqlite3_trace_v2(m_db, SQLITE_TRACE_STMT, TraceSqlCallback, this); + if (ret != SQLITE_OK) { + LogPrintf("Failed to enable SQL tracing for %s\n", Filename()); + } + } } if (sqlite3_db_readonly(m_db, "main") != 0) { @@ -409,6 +436,7 @@ bool SQLiteBatch::ReadKey(DataStream&& key, DataStream& value) // Leftmost column in result is index 0 const std::byte* data{AsBytePtr(sqlite3_column_blob(m_read_stmt, 0))}; size_t data_size(sqlite3_column_bytes(m_read_stmt, 0)); + value.clear(); value.write({data, data_size}); sqlite3_clear_bindings(m_read_stmt); @@ -495,6 +523,9 @@ DatabaseCursor::Status SQLiteCursor::Next(DataStream& key, DataStream& value) return Status::FAIL; } + key.clear(); + value.clear(); + // Leftmost column in result is index 0 const std::byte* key_data{AsBytePtr(sqlite3_column_blob(m_cursor_stmt, 0))}; size_t key_data_size(sqlite3_column_bytes(m_cursor_stmt, 0)); @@ -507,6 +538,7 @@ DatabaseCursor::Status SQLiteCursor::Next(DataStream& key, DataStream& value) SQLiteCursor::~SQLiteCursor() { + sqlite3_clear_bindings(m_cursor_stmt); sqlite3_reset(m_cursor_stmt); int res = sqlite3_finalize(m_cursor_stmt); if (res != SQLITE_OK) { @@ -530,6 +562,48 @@ std::unique_ptr<DatabaseCursor> SQLiteBatch::GetNewCursor() return cursor; } +std::unique_ptr<DatabaseCursor> SQLiteBatch::GetNewPrefixCursor(Span<const std::byte> prefix) +{ + if (!m_database.m_db) return nullptr; + + // To get just the records we want, the SQL statement does a comparison of the binary data + // where the data must be greater than or equal to the prefix, and less than + // the prefix incremented by one (when interpreted as an integer) + std::vector<std::byte> start_range(prefix.begin(), prefix.end()); + std::vector<std::byte> end_range(prefix.begin(), prefix.end()); + auto it = end_range.rbegin(); + for (; it != end_range.rend(); ++it) { + if (*it == std::byte(std::numeric_limits<unsigned char>::max())) { + *it = std::byte(0); + continue; + } + *it = std::byte(std::to_integer<unsigned char>(*it) + 1); + break; + } + if (it == end_range.rend()) { + // If the prefix is all 0xff bytes, clear end_range as we won't need it + end_range.clear(); + } + + auto cursor = std::make_unique<SQLiteCursor>(start_range, end_range); + if (!cursor) return nullptr; + + const char* stmt_text = end_range.empty() ? "SELECT key, value FROM main WHERE key >= ?" : + "SELECT key, value FROM main WHERE key >= ? AND key < ?"; + int res = sqlite3_prepare_v2(m_database.m_db, stmt_text, -1, &cursor->m_cursor_stmt, nullptr); + if (res != SQLITE_OK) { + throw std::runtime_error(strprintf( + "SQLiteDatabase: Failed to setup cursor SQL statement: %s\n", sqlite3_errstr(res))); + } + + if (!BindBlobToStatement(cursor->m_cursor_stmt, 1, cursor->m_prefix_range_start, "prefix_start")) return nullptr; + if (!end_range.empty()) { + if (!BindBlobToStatement(cursor->m_cursor_stmt, 2, cursor->m_prefix_range_end, "prefix_end")) return nullptr; + } + + return cursor; +} + bool SQLiteBatch::TxnBegin() { if (!m_database.m_db || sqlite3_get_autocommit(m_database.m_db) == 0) return false; diff --git a/src/wallet/sqlite.h b/src/wallet/sqlite.h index d9de40569b..0378bbb8d6 100644 --- a/src/wallet/sqlite.h +++ b/src/wallet/sqlite.h @@ -15,12 +15,21 @@ struct bilingual_str; namespace wallet { class SQLiteDatabase; +/** RAII class that provides a database cursor */ class SQLiteCursor : public DatabaseCursor { public: sqlite3_stmt* m_cursor_stmt{nullptr}; + // Copies of the prefix things for the prefix cursor. + // Prevents SQLite from accessing temp variables for the prefix things. + std::vector<std::byte> m_prefix_range_start; + std::vector<std::byte> m_prefix_range_end; explicit SQLiteCursor() {} + explicit SQLiteCursor(std::vector<std::byte> start_range, std::vector<std::byte> end_range) + : m_prefix_range_start(std::move(start_range)), + m_prefix_range_end(std::move(end_range)) + {} ~SQLiteCursor() override; Status Next(DataStream& key, DataStream& value) override; @@ -57,6 +66,7 @@ public: void Close() override; std::unique_ptr<DatabaseCursor> GetNewCursor() override; + std::unique_ptr<DatabaseCursor> GetNewPrefixCursor(Span<const std::byte> prefix) override; bool TxnBegin() override; bool TxnCommit() override; bool TxnAbort() override; diff --git a/src/wallet/test/db_tests.cpp b/src/wallet/test/db_tests.cpp index 7761308bbc..4cda35ed8d 100644 --- a/src/wallet/test/db_tests.cpp +++ b/src/wallet/test/db_tests.cpp @@ -6,13 +6,56 @@ #include <test/util/setup_common.h> #include <util/fs.h> +#include <util/translation.h> +#ifdef USE_BDB #include <wallet/bdb.h> +#endif +#ifdef USE_SQLITE +#include <wallet/sqlite.h> +#endif +#include <wallet/test/util.h> +#include <wallet/walletutil.h> // for WALLET_FLAG_DESCRIPTORS #include <fstream> #include <memory> #include <string> +inline std::ostream& operator<<(std::ostream& os, const std::pair<const SerializeData, SerializeData>& kv) +{ + Span key{kv.first}, value{kv.second}; + os << "(\"" << std::string_view{reinterpret_cast<const char*>(key.data()), key.size()} << "\", \"" + << std::string_view{reinterpret_cast<const char*>(key.data()), key.size()} << "\")"; + return os; +} + namespace wallet { + +static Span<const std::byte> StringBytes(std::string_view str) +{ + return AsBytes<const char>({str.data(), str.size()}); +} + +static SerializeData StringData(std::string_view str) +{ + auto bytes = StringBytes(str); + return SerializeData{bytes.begin(), bytes.end()}; +} + +static void CheckPrefix(DatabaseBatch& batch, Span<const std::byte> prefix, MockableData expected) +{ + std::unique_ptr<DatabaseCursor> cursor = batch.GetNewPrefixCursor(prefix); + MockableData actual; + while (true) { + DataStream key, value; + DatabaseCursor::Status status = cursor->Next(key, value); + if (status == DatabaseCursor::Status::DONE) break; + BOOST_CHECK(status == DatabaseCursor::Status::MORE); + BOOST_CHECK( + actual.emplace(SerializeData(key.begin(), key.end()), SerializeData(value.begin(), value.end())).second); + } + BOOST_CHECK_EQUAL_COLLECTIONS(actual.begin(), actual.end(), expected.begin(), expected.end()); +} + BOOST_FIXTURE_TEST_SUITE(db_tests, BasicTestingSetup) static std::shared_ptr<BerkeleyEnvironment> GetWalletEnv(const fs::path& path, fs::path& database_filename) @@ -78,5 +121,90 @@ BOOST_AUTO_TEST_CASE(getwalletenv_g_dbenvs_free_instance) BOOST_CHECK(env_2_a == env_2_b); } +static std::vector<std::unique_ptr<WalletDatabase>> TestDatabases(const fs::path& path_root) +{ + std::vector<std::unique_ptr<WalletDatabase>> dbs; + DatabaseOptions options; + DatabaseStatus status; + bilingual_str error; +#ifdef USE_BDB + dbs.emplace_back(MakeBerkeleyDatabase(path_root / "bdb", options, status, error)); +#endif +#ifdef USE_SQLITE + dbs.emplace_back(MakeSQLiteDatabase(path_root / "sqlite", options, status, error)); +#endif + dbs.emplace_back(CreateMockableWalletDatabase()); + return dbs; +} + +BOOST_AUTO_TEST_CASE(db_cursor_prefix_range_test) +{ + // Test each supported db + for (const auto& database : TestDatabases(m_path_root)) { + BOOST_ASSERT(database); + + std::vector<std::string> prefixes = {"", "FIRST", "SECOND", "P\xfe\xff", "P\xff\x01", "\xff\xff"}; + + // Write elements to it + std::unique_ptr<DatabaseBatch> handler = database->MakeBatch(); + for (unsigned int i = 0; i < 10; i++) { + for (const auto& prefix : prefixes) { + BOOST_CHECK(handler->Write(std::make_pair(prefix, i), i)); + } + } + + // Now read all the items by prefix and verify that each element gets parsed correctly + for (const auto& prefix : prefixes) { + DataStream s_prefix; + s_prefix << prefix; + std::unique_ptr<DatabaseCursor> cursor = handler->GetNewPrefixCursor(s_prefix); + DataStream key; + DataStream value; + for (int i = 0; i < 10; i++) { + DatabaseCursor::Status status = cursor->Next(key, value); + BOOST_ASSERT(status == DatabaseCursor::Status::MORE); + + std::string key_back; + unsigned int i_back; + key >> key_back >> i_back; + BOOST_CHECK_EQUAL(key_back, prefix); + + unsigned int value_back; + value >> value_back; + BOOST_CHECK_EQUAL(value_back, i_back); + } + + // Let's now read it once more, it should return DONE + BOOST_CHECK(cursor->Next(key, value) == DatabaseCursor::Status::DONE); + } + } +} + +// Lower level DatabaseBase::GetNewPrefixCursor test, to cover cases that aren't +// covered in the higher level test above. The higher level test uses +// serialized strings which are prefixed with string length, so it doesn't test +// truly empty prefixes or prefixes that begin with \xff +BOOST_AUTO_TEST_CASE(db_cursor_prefix_byte_test) +{ + const MockableData::value_type + e{StringData(""), StringData("e")}, + p{StringData("prefix"), StringData("p")}, + ps{StringData("prefixsuffix"), StringData("ps")}, + f{StringData("\xff"), StringData("f")}, + fs{StringData("\xffsuffix"), StringData("fs")}, + ff{StringData("\xff\xff"), StringData("ff")}, + ffs{StringData("\xff\xffsuffix"), StringData("ffs")}; + for (const auto& database : TestDatabases(m_path_root)) { + std::unique_ptr<DatabaseBatch> batch = database->MakeBatch(); + for (const auto& [k, v] : {e, p, ps, f, fs, ff, ffs}) { + batch->Write(MakeUCharSpan(k), MakeUCharSpan(v)); + } + CheckPrefix(*batch, StringBytes(""), {e, p, ps, f, fs, ff, ffs}); + CheckPrefix(*batch, StringBytes("prefix"), {p, ps}); + CheckPrefix(*batch, StringBytes("\xff"), {f, fs, ff, ffs}); + CheckPrefix(*batch, StringBytes("\xff\xff"), {ff, ffs}); + } +} + BOOST_AUTO_TEST_SUITE_END() } // namespace wallet diff --git a/src/wallet/test/util.cpp b/src/wallet/test/util.cpp index eacb70cd69..069ab25f26 100644 --- a/src/wallet/test/util.cpp +++ b/src/wallet/test/util.cpp @@ -92,6 +92,17 @@ CTxDestination getNewDestination(CWallet& w, OutputType output_type) return *Assert(w.GetNewDestination(output_type, "")); } +// BytePrefix compares equality with other byte spans that begin with the same prefix. +struct BytePrefix { Span<const std::byte> prefix; }; +bool operator<(BytePrefix a, Span<const std::byte> b) { return a.prefix < b.subspan(0, std::min(a.prefix.size(), b.size())); } +bool operator<(Span<const std::byte> a, BytePrefix b) { return a.subspan(0, std::min(a.size(), b.prefix.size())) < b.prefix; } + +MockableCursor::MockableCursor(const MockableData& records, bool pass, Span<const std::byte> prefix) +{ + m_pass = pass; + std::tie(m_cursor, m_cursor_end) = records.equal_range(BytePrefix{prefix}); +} + DatabaseCursor::Status MockableCursor::Next(DataStream& key, DataStream& value) { if (!m_pass) { @@ -100,6 +111,8 @@ DatabaseCursor::Status MockableCursor::Next(DataStream& key, DataStream& value) if (m_cursor == m_cursor_end) { return Status::DONE; } + key.clear(); + value.clear(); const auto& [key_data, value_data] = *m_cursor; key.write(key_data); value.write(value_data); @@ -117,6 +130,7 @@ bool MockableBatch::ReadKey(DataStream&& key, DataStream& value) if (it == m_records.end()) { return false; } + value.clear(); value.write(it->second); return true; } @@ -172,7 +186,7 @@ bool MockableBatch::ErasePrefix(Span<const std::byte> prefix) return true; } -std::unique_ptr<WalletDatabase> CreateMockableWalletDatabase(std::map<SerializeData, SerializeData> records) +std::unique_ptr<WalletDatabase> CreateMockableWalletDatabase(MockableData records) { return std::make_unique<MockableDatabase>(records); } diff --git a/src/wallet/test/util.h b/src/wallet/test/util.h index b1ad1c959b..2a1fe639de 100644 --- a/src/wallet/test/util.h +++ b/src/wallet/test/util.h @@ -48,14 +48,17 @@ std::string getnewaddress(CWallet& w); /** Returns a new destination, of an specific type, from the wallet */ CTxDestination getNewDestination(CWallet& w, OutputType output_type); +using MockableData = std::map<SerializeData, SerializeData, std::less<>>; + class MockableCursor: public DatabaseCursor { public: - std::map<SerializeData, SerializeData>::const_iterator m_cursor; - std::map<SerializeData, SerializeData>::const_iterator m_cursor_end; + MockableData::const_iterator m_cursor; + MockableData::const_iterator m_cursor_end; bool m_pass; - explicit MockableCursor(const std::map<SerializeData, SerializeData>& records, bool pass) : m_cursor(records.begin()), m_cursor_end(records.end()), m_pass(pass) {} + explicit MockableCursor(const MockableData& records, bool pass) : m_cursor(records.begin()), m_cursor_end(records.end()), m_pass(pass) {} + MockableCursor(const MockableData& records, bool pass, Span<const std::byte> prefix); ~MockableCursor() {} Status Next(DataStream& key, DataStream& value) override; @@ -64,7 +67,7 @@ public: class MockableBatch : public DatabaseBatch { private: - std::map<SerializeData, SerializeData>& m_records; + MockableData& m_records; bool m_pass; bool ReadKey(DataStream&& key, DataStream& value) override; @@ -74,7 +77,7 @@ private: bool ErasePrefix(Span<const std::byte> prefix) override; public: - explicit MockableBatch(std::map<SerializeData, SerializeData>& records, bool pass) : m_records(records), m_pass(pass) {} + explicit MockableBatch(MockableData& records, bool pass) : m_records(records), m_pass(pass) {} ~MockableBatch() {} void Flush() override {} @@ -84,6 +87,9 @@ public: { return std::make_unique<MockableCursor>(m_records, m_pass); } + std::unique_ptr<DatabaseCursor> GetNewPrefixCursor(Span<const std::byte> prefix) override { + return std::make_unique<MockableCursor>(m_records, m_pass, prefix); + } bool TxnBegin() override { return m_pass; } bool TxnCommit() override { return m_pass; } bool TxnAbort() override { return m_pass; } @@ -94,10 +100,10 @@ public: class MockableDatabase : public WalletDatabase { public: - std::map<SerializeData, SerializeData> m_records; + MockableData m_records; bool m_pass{true}; - MockableDatabase(std::map<SerializeData, SerializeData> records = {}) : WalletDatabase(), m_records(records) {} + MockableDatabase(MockableData records = {}) : WalletDatabase(), m_records(records) {} ~MockableDatabase() {}; void Open() override {} @@ -117,7 +123,7 @@ public: std::unique_ptr<DatabaseBatch> MakeBatch(bool flush_on_close = true) override { return std::make_unique<MockableBatch>(m_records, m_pass); } }; -std::unique_ptr<WalletDatabase> CreateMockableWalletDatabase(std::map<SerializeData, SerializeData> records = {}); +std::unique_ptr<WalletDatabase> CreateMockableWalletDatabase(MockableData records = {}); MockableDatabase& GetMockableDatabase(CWallet& wallet); } // namespace wallet diff --git a/src/wallet/test/walletload_tests.cpp b/src/wallet/test/walletload_tests.cpp index 6823eafdfa..c1ff7baae1 100644 --- a/src/wallet/test/walletload_tests.cpp +++ b/src/wallet/test/walletload_tests.cpp @@ -83,7 +83,7 @@ BOOST_FIXTURE_TEST_CASE(wallet_load_ckey, TestingSetup) { SerializeData ckey_record_key; SerializeData ckey_record_value; - std::map<SerializeData, SerializeData> records; + MockableData records; { // Context setup. |