diff options
Diffstat (limited to 'src/script')
-rw-r--r-- | src/script/descriptor.cpp | 251 | ||||
-rw-r--r-- | src/script/descriptor.h | 51 |
2 files changed, 216 insertions, 86 deletions
diff --git a/src/script/descriptor.cpp b/src/script/descriptor.cpp index b1d9a5bda7..83dc046ca1 100644 --- a/src/script/descriptor.cpp +++ b/src/script/descriptor.cpp @@ -150,10 +150,22 @@ typedef std::vector<uint32_t> KeyPath; /** Interface for public key objects in descriptors. */ struct PubkeyProvider { +protected: + //! Index of this key expression in the descriptor + //! E.g. If this PubkeyProvider is key1 in multi(2, key1, key2, key3), then m_expr_index = 0 + uint32_t m_expr_index; + +public: + PubkeyProvider(uint32_t exp_index) : m_expr_index(exp_index) {} + virtual ~PubkeyProvider() = default; - /** Derive a public key. If key==nullptr, only info is desired. */ - virtual bool GetPubKey(int pos, const SigningProvider& arg, CPubKey* key, KeyOriginInfo& info) const = 0; + /** Derive a public key. + * read_cache is the cache to read keys from (if not nullptr) + * write_cache is the cache to write keys to (if not nullptr) + * Caches are not exclusive but this is not tested. Currently we use them exclusively + */ + virtual bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key, KeyOriginInfo& info, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) = 0; /** Whether this represent multiple public keys at different positions. */ virtual bool IsRange() const = 0; @@ -182,10 +194,10 @@ class OriginPubkeyProvider final : public PubkeyProvider } public: - OriginPubkeyProvider(KeyOriginInfo info, std::unique_ptr<PubkeyProvider> provider) : m_origin(std::move(info)), m_provider(std::move(provider)) {} - bool GetPubKey(int pos, const SigningProvider& arg, CPubKey* key, KeyOriginInfo& info) const override + OriginPubkeyProvider(uint32_t exp_index, KeyOriginInfo info, std::unique_ptr<PubkeyProvider> provider) : PubkeyProvider(exp_index), m_origin(std::move(info)), m_provider(std::move(provider)) {} + bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key, KeyOriginInfo& info, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) override { - if (!m_provider->GetPubKey(pos, arg, key, info)) return false; + if (!m_provider->GetPubKey(pos, arg, key, info, read_cache, write_cache)) return false; std::copy(std::begin(m_origin.fingerprint), std::end(m_origin.fingerprint), info.fingerprint); info.path.insert(info.path.begin(), m_origin.path.begin(), m_origin.path.end()); return true; @@ -212,10 +224,10 @@ class ConstPubkeyProvider final : public PubkeyProvider CPubKey m_pubkey; public: - ConstPubkeyProvider(const CPubKey& pubkey) : m_pubkey(pubkey) {} - bool GetPubKey(int pos, const SigningProvider& arg, CPubKey* key, KeyOriginInfo& info) const override + ConstPubkeyProvider(uint32_t exp_index, const CPubKey& pubkey) : PubkeyProvider(exp_index), m_pubkey(pubkey) {} + bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key, KeyOriginInfo& info, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) override { - if (key) *key = m_pubkey; + key = m_pubkey; info.path.clear(); CKeyID keyid = m_pubkey.GetID(); std::copy(keyid.begin(), keyid.begin() + sizeof(info.fingerprint), info.fingerprint); @@ -246,22 +258,36 @@ enum class DeriveType { /** An object representing a parsed extended public key in a descriptor. */ class BIP32PubkeyProvider final : public PubkeyProvider { - CExtPubKey m_extkey; + // Root xpub, path, and final derivation step type being used, if any + CExtPubKey m_root_extkey; KeyPath m_path; DeriveType m_derive; + // Cache of the parent of the final derived pubkeys. + // Primarily useful for situations when no read_cache is provided + CExtPubKey m_cached_xpub; bool GetExtKey(const SigningProvider& arg, CExtKey& ret) const { CKey key; - if (!arg.GetKey(m_extkey.pubkey.GetID(), key)) return false; - ret.nDepth = m_extkey.nDepth; - std::copy(m_extkey.vchFingerprint, m_extkey.vchFingerprint + sizeof(ret.vchFingerprint), ret.vchFingerprint); - ret.nChild = m_extkey.nChild; - ret.chaincode = m_extkey.chaincode; + if (!arg.GetKey(m_root_extkey.pubkey.GetID(), key)) return false; + ret.nDepth = m_root_extkey.nDepth; + std::copy(m_root_extkey.vchFingerprint, m_root_extkey.vchFingerprint + sizeof(ret.vchFingerprint), ret.vchFingerprint); + ret.nChild = m_root_extkey.nChild; + ret.chaincode = m_root_extkey.chaincode; ret.key = key; return true; } + // Derives the last xprv + bool GetDerivedExtKey(const SigningProvider& arg, CExtKey& xprv) const + { + if (!GetExtKey(arg, xprv)) return false; + for (auto entry : m_path) { + xprv.Derive(xprv, entry); + } + return true; + } + bool IsHardened() const { if (m_derive == DeriveType::HARDENED) return true; @@ -272,37 +298,77 @@ class BIP32PubkeyProvider final : public PubkeyProvider } public: - BIP32PubkeyProvider(const CExtPubKey& extkey, KeyPath path, DeriveType derive) : m_extkey(extkey), m_path(std::move(path)), m_derive(derive) {} + BIP32PubkeyProvider(uint32_t exp_index, const CExtPubKey& extkey, KeyPath path, DeriveType derive) : PubkeyProvider(exp_index), m_root_extkey(extkey), m_path(std::move(path)), m_derive(derive) {} bool IsRange() const override { return m_derive != DeriveType::NO; } size_t GetSize() const override { return 33; } - bool GetPubKey(int pos, const SigningProvider& arg, CPubKey* key, KeyOriginInfo& info) const override + bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key_out, KeyOriginInfo& final_info_out, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) override { - if (key) { - if (IsHardened()) { - CKey priv_key; - if (!GetPrivKey(pos, arg, priv_key)) return false; - *key = priv_key.GetPubKey(); - } else { - // TODO: optimize by caching - CExtPubKey extkey = m_extkey; - for (auto entry : m_path) { - extkey.Derive(extkey, entry); - } - if (m_derive == DeriveType::UNHARDENED) extkey.Derive(extkey, pos); - assert(m_derive != DeriveType::HARDENED); - *key = extkey.pubkey; + // Info of parent of the to be derived pubkey + KeyOriginInfo parent_info; + CKeyID keyid = m_root_extkey.pubkey.GetID(); + std::copy(keyid.begin(), keyid.begin() + sizeof(parent_info.fingerprint), parent_info.fingerprint); + parent_info.path = m_path; + + // Info of the derived key itself which is copied out upon successful completion + KeyOriginInfo final_info_out_tmp = parent_info; + if (m_derive == DeriveType::UNHARDENED) final_info_out_tmp.path.push_back((uint32_t)pos); + if (m_derive == DeriveType::HARDENED) final_info_out_tmp.path.push_back(((uint32_t)pos) | 0x80000000L); + + // Derive keys or fetch them from cache + CExtPubKey final_extkey = m_root_extkey; + CExtPubKey parent_extkey = m_root_extkey; + bool der = true; + if (read_cache) { + if (!read_cache->GetCachedDerivedExtPubKey(m_expr_index, pos, final_extkey)) { + if (m_derive == DeriveType::HARDENED) return false; + // Try to get the derivation parent + if (!read_cache->GetCachedParentExtPubKey(m_expr_index, parent_extkey)) return false; + final_extkey = parent_extkey; + if (m_derive == DeriveType::UNHARDENED) der = parent_extkey.Derive(final_extkey, pos); + } + } else if (m_cached_xpub.pubkey.IsValid() && m_derive != DeriveType::HARDENED) { + parent_extkey = final_extkey = m_cached_xpub; + if (m_derive == DeriveType::UNHARDENED) der = parent_extkey.Derive(final_extkey, pos); + } else if (IsHardened()) { + CExtKey xprv; + if (!GetDerivedExtKey(arg, xprv)) return false; + parent_extkey = xprv.Neuter(); + if (m_derive == DeriveType::UNHARDENED) der = xprv.Derive(xprv, pos); + if (m_derive == DeriveType::HARDENED) der = xprv.Derive(xprv, pos | 0x80000000UL); + final_extkey = xprv.Neuter(); + } else { + for (auto entry : m_path) { + der = parent_extkey.Derive(parent_extkey, entry); + assert(der); } + final_extkey = parent_extkey; + if (m_derive == DeriveType::UNHARDENED) der = parent_extkey.Derive(final_extkey, pos); + assert(m_derive != DeriveType::HARDENED); } - CKeyID keyid = m_extkey.pubkey.GetID(); - std::copy(keyid.begin(), keyid.begin() + sizeof(info.fingerprint), info.fingerprint); - info.path = m_path; - if (m_derive == DeriveType::UNHARDENED) info.path.push_back((uint32_t)pos); - if (m_derive == DeriveType::HARDENED) info.path.push_back(((uint32_t)pos) | 0x80000000L); + assert(der); + + final_info_out = final_info_out_tmp; + key_out = final_extkey.pubkey; + + // We rely on the consumer to check that m_derive isn't HARDENED as above + // But we can't have already cached something in case we read something from the cache + // and parent_extkey isn't actually the parent. + if (!m_cached_xpub.pubkey.IsValid()) m_cached_xpub = parent_extkey; + + if (write_cache) { + // Only cache parent if there is any unhardened derivation + if (m_derive != DeriveType::HARDENED) { + write_cache->CacheParentExtPubKey(m_expr_index, parent_extkey); + } else if (final_info_out.path.size() > 0) { + write_cache->CacheDerivedExtPubKey(m_expr_index, pos, final_extkey); + } + } + return true; } std::string ToString() const override { - std::string ret = EncodeExtPubKey(m_extkey) + FormatHDKeypath(m_path); + std::string ret = EncodeExtPubKey(m_root_extkey) + FormatHDKeypath(m_path); if (IsRange()) { ret += "/*"; if (m_derive == DeriveType::HARDENED) ret += '\''; @@ -323,10 +389,7 @@ public: bool GetPrivKey(int pos, const SigningProvider& arg, CKey& key) const override { CExtKey extkey; - if (!GetExtKey(arg, extkey)) return false; - for (auto entry : m_path) { - extkey.Derive(extkey, entry); - } + if (!GetDerivedExtKey(arg, extkey)) return false; if (m_derive == DeriveType::UNHARDENED) extkey.Derive(extkey, pos); if (m_derive == DeriveType::HARDENED) extkey.Derive(extkey, pos | 0x80000000UL); key = extkey.key; @@ -425,7 +488,7 @@ public: return ret; } - bool ExpandHelper(int pos, const SigningProvider& arg, Span<const unsigned char>* cache_read, std::vector<CScript>& output_scripts, FlatSigningProvider& out, std::vector<unsigned char>* cache_write) const + bool ExpandHelper(int pos, const SigningProvider& arg, const DescriptorCache* read_cache, std::vector<CScript>& output_scripts, FlatSigningProvider& out, DescriptorCache* write_cache) const { std::vector<std::pair<CPubKey, KeyOriginInfo>> entries; entries.reserve(m_pubkey_args.size()); @@ -433,27 +496,12 @@ public: // Construct temporary data in `entries` and `subscripts`, to avoid producing output in case of failure. for (const auto& p : m_pubkey_args) { entries.emplace_back(); - // If we have a cache, we don't need GetPubKey to compute the public key. - // Pass in nullptr to signify only origin info is desired. - if (!p->GetPubKey(pos, arg, cache_read ? nullptr : &entries.back().first, entries.back().second)) return false; - if (cache_read) { - // Cached expanded public key exists, use it. - if (cache_read->size() == 0) return false; - bool compressed = ((*cache_read)[0] == 0x02 || (*cache_read)[0] == 0x03) && cache_read->size() >= 33; - bool uncompressed = ((*cache_read)[0] == 0x04) && cache_read->size() >= 65; - if (!(compressed || uncompressed)) return false; - CPubKey pubkey(cache_read->begin(), cache_read->begin() + (compressed ? 33 : 65)); - entries.back().first = pubkey; - *cache_read = cache_read->subspan(compressed ? 33 : 65); - } - if (cache_write) { - cache_write->insert(cache_write->end(), entries.back().first.begin(), entries.back().first.end()); - } + if (!p->GetPubKey(pos, arg, entries.back().first, entries.back().second, read_cache, write_cache)) return false; } std::vector<CScript> subscripts; if (m_subdescriptor_arg) { FlatSigningProvider subprovider; - if (!m_subdescriptor_arg->ExpandHelper(pos, arg, cache_read, subscripts, subprovider, cache_write)) return false; + if (!m_subdescriptor_arg->ExpandHelper(pos, arg, read_cache, subscripts, subprovider, write_cache)) return false; out = Merge(out, subprovider); } @@ -477,15 +525,14 @@ public: return true; } - bool Expand(int pos, const SigningProvider& provider, std::vector<CScript>& output_scripts, FlatSigningProvider& out, std::vector<unsigned char>* cache = nullptr) const final + bool Expand(int pos, const SigningProvider& provider, std::vector<CScript>& output_scripts, FlatSigningProvider& out, DescriptorCache* write_cache = nullptr) const final { - return ExpandHelper(pos, provider, nullptr, output_scripts, out, cache); + return ExpandHelper(pos, provider, nullptr, output_scripts, out, write_cache); } - bool ExpandFromCache(int pos, const std::vector<unsigned char>& cache, std::vector<CScript>& output_scripts, FlatSigningProvider& out) const final + bool ExpandFromCache(int pos, const DescriptorCache& read_cache, std::vector<CScript>& output_scripts, FlatSigningProvider& out) const final { - Span<const unsigned char> span = MakeSpan(cache); - return ExpandHelper(pos, DUMMY_SIGNING_PROVIDER, &span, output_scripts, out, nullptr) && span.size() == 0; + return ExpandHelper(pos, DUMMY_SIGNING_PROVIDER, &read_cache, output_scripts, out, nullptr); } void ExpandPrivate(int pos, const SigningProvider& provider, FlatSigningProvider& out) const final @@ -698,7 +745,7 @@ NODISCARD bool ParseKeyPath(const std::vector<Span<const char>>& split, KeyPath& } /** Parse a public key that excludes origin information. */ -std::unique_ptr<PubkeyProvider> ParsePubkeyInner(const Span<const char>& sp, bool permit_uncompressed, FlatSigningProvider& out, std::string& error) +std::unique_ptr<PubkeyProvider> ParsePubkeyInner(uint32_t key_exp_index, const Span<const char>& sp, bool permit_uncompressed, FlatSigningProvider& out, std::string& error) { using namespace spanparsing; @@ -714,7 +761,7 @@ std::unique_ptr<PubkeyProvider> ParsePubkeyInner(const Span<const char>& sp, boo CPubKey pubkey(data); if (pubkey.IsFullyValid()) { if (permit_uncompressed || pubkey.IsCompressed()) { - return MakeUnique<ConstPubkeyProvider>(pubkey); + return MakeUnique<ConstPubkeyProvider>(key_exp_index, pubkey); } else { error = "Uncompressed keys are not allowed"; return nullptr; @@ -728,7 +775,7 @@ std::unique_ptr<PubkeyProvider> ParsePubkeyInner(const Span<const char>& sp, boo if (permit_uncompressed || key.IsCompressed()) { CPubKey pubkey = key.GetPubKey(); out.keys.emplace(pubkey.GetID(), key); - return MakeUnique<ConstPubkeyProvider>(pubkey); + return MakeUnique<ConstPubkeyProvider>(key_exp_index, pubkey); } else { error = "Uncompressed keys are not allowed"; return nullptr; @@ -755,11 +802,11 @@ std::unique_ptr<PubkeyProvider> ParsePubkeyInner(const Span<const char>& sp, boo extpubkey = extkey.Neuter(); out.keys.emplace(extpubkey.pubkey.GetID(), extkey.key); } - return MakeUnique<BIP32PubkeyProvider>(extpubkey, std::move(path), type); + return MakeUnique<BIP32PubkeyProvider>(key_exp_index, extpubkey, std::move(path), type); } /** Parse a public key including origin information (if enabled). */ -std::unique_ptr<PubkeyProvider> ParsePubkey(const Span<const char>& sp, bool permit_uncompressed, FlatSigningProvider& out, std::string& error) +std::unique_ptr<PubkeyProvider> ParsePubkey(uint32_t key_exp_index, const Span<const char>& sp, bool permit_uncompressed, FlatSigningProvider& out, std::string& error) { using namespace spanparsing; @@ -768,7 +815,7 @@ std::unique_ptr<PubkeyProvider> ParsePubkey(const Span<const char>& sp, bool per error = "Multiple ']' characters found for a single pubkey"; return nullptr; } - if (origin_split.size() == 1) return ParsePubkeyInner(origin_split[0], permit_uncompressed, out, error); + if (origin_split.size() == 1) return ParsePubkeyInner(key_exp_index, origin_split[0], permit_uncompressed, out, error); if (origin_split[0].size() < 1 || origin_split[0][0] != '[') { error = strprintf("Key origin start '[ character expected but not found, got '%c' instead", origin_split[0][0]); return nullptr; @@ -789,30 +836,30 @@ std::unique_ptr<PubkeyProvider> ParsePubkey(const Span<const char>& sp, bool per assert(fpr_bytes.size() == 4); std::copy(fpr_bytes.begin(), fpr_bytes.end(), info.fingerprint); if (!ParseKeyPath(slash_split, info.path, error)) return nullptr; - auto provider = ParsePubkeyInner(origin_split[1], permit_uncompressed, out, error); + auto provider = ParsePubkeyInner(key_exp_index, origin_split[1], permit_uncompressed, out, error); if (!provider) return nullptr; - return MakeUnique<OriginPubkeyProvider>(std::move(info), std::move(provider)); + return MakeUnique<OriginPubkeyProvider>(key_exp_index, std::move(info), std::move(provider)); } /** Parse a script in a particular context. */ -std::unique_ptr<DescriptorImpl> ParseScript(Span<const char>& sp, ParseScriptContext ctx, FlatSigningProvider& out, std::string& error) +std::unique_ptr<DescriptorImpl> ParseScript(uint32_t key_exp_index, Span<const char>& sp, ParseScriptContext ctx, FlatSigningProvider& out, std::string& error) { using namespace spanparsing; auto expr = Expr(sp); bool sorted_multi = false; if (Func("pk", expr)) { - auto pubkey = ParsePubkey(expr, ctx != ParseScriptContext::P2WSH, out, error); + auto pubkey = ParsePubkey(key_exp_index, expr, ctx != ParseScriptContext::P2WSH, out, error); if (!pubkey) return nullptr; return MakeUnique<PKDescriptor>(std::move(pubkey)); } if (Func("pkh", expr)) { - auto pubkey = ParsePubkey(expr, ctx != ParseScriptContext::P2WSH, out, error); + auto pubkey = ParsePubkey(key_exp_index, expr, ctx != ParseScriptContext::P2WSH, out, error); if (!pubkey) return nullptr; return MakeUnique<PKHDescriptor>(std::move(pubkey)); } if (ctx == ParseScriptContext::TOP && Func("combo", expr)) { - auto pubkey = ParsePubkey(expr, true, out, error); + auto pubkey = ParsePubkey(key_exp_index, expr, true, out, error); if (!pubkey) return nullptr; return MakeUnique<ComboDescriptor>(std::move(pubkey)); } else if (ctx != ParseScriptContext::TOP && Func("combo", expr)) { @@ -834,10 +881,11 @@ std::unique_ptr<DescriptorImpl> ParseScript(Span<const char>& sp, ParseScriptCon return nullptr; } auto arg = Expr(expr); - auto pk = ParsePubkey(arg, ctx != ParseScriptContext::P2WSH, out, error); + auto pk = ParsePubkey(key_exp_index, arg, ctx != ParseScriptContext::P2WSH, out, error); if (!pk) return nullptr; script_size += pk->GetSize() + 1; providers.emplace_back(std::move(pk)); + key_exp_index++; } if (providers.size() < 1 || providers.size() > 16) { error = strprintf("Cannot have %u keys in multisig; must have between 1 and 16 keys, inclusive", providers.size()); @@ -864,7 +912,7 @@ std::unique_ptr<DescriptorImpl> ParseScript(Span<const char>& sp, ParseScriptCon return MakeUnique<MultisigDescriptor>(thres, std::move(providers), sorted_multi); } if (ctx != ParseScriptContext::P2WSH && Func("wpkh", expr)) { - auto pubkey = ParsePubkey(expr, false, out, error); + auto pubkey = ParsePubkey(key_exp_index, expr, false, out, error); if (!pubkey) return nullptr; return MakeUnique<WPKHDescriptor>(std::move(pubkey)); } else if (ctx == ParseScriptContext::P2WSH && Func("wpkh", expr)) { @@ -872,7 +920,7 @@ std::unique_ptr<DescriptorImpl> ParseScript(Span<const char>& sp, ParseScriptCon return nullptr; } if (ctx == ParseScriptContext::TOP && Func("sh", expr)) { - auto desc = ParseScript(expr, ParseScriptContext::P2SH, out, error); + auto desc = ParseScript(key_exp_index, expr, ParseScriptContext::P2SH, out, error); if (!desc || expr.size()) return nullptr; return MakeUnique<SHDescriptor>(std::move(desc)); } else if (ctx != ParseScriptContext::TOP && Func("sh", expr)) { @@ -880,7 +928,7 @@ std::unique_ptr<DescriptorImpl> ParseScript(Span<const char>& sp, ParseScriptCon return nullptr; } if (ctx != ParseScriptContext::P2WSH && Func("wsh", expr)) { - auto desc = ParseScript(expr, ParseScriptContext::P2WSH, out, error); + auto desc = ParseScript(key_exp_index, expr, ParseScriptContext::P2WSH, out, error); if (!desc || expr.size()) return nullptr; return MakeUnique<WSHDescriptor>(std::move(desc)); } else if (ctx == ParseScriptContext::P2WSH && Func("wsh", expr)) { @@ -917,10 +965,10 @@ std::unique_ptr<DescriptorImpl> ParseScript(Span<const char>& sp, ParseScriptCon std::unique_ptr<PubkeyProvider> InferPubkey(const CPubKey& pubkey, ParseScriptContext, const SigningProvider& provider) { - std::unique_ptr<PubkeyProvider> key_provider = MakeUnique<ConstPubkeyProvider>(pubkey); + std::unique_ptr<PubkeyProvider> key_provider = MakeUnique<ConstPubkeyProvider>(0, pubkey); KeyOriginInfo info; if (provider.GetKeyOrigin(pubkey.GetID(), info)) { - return MakeUnique<OriginPubkeyProvider>(std::move(info), std::move(key_provider)); + return MakeUnique<OriginPubkeyProvider>(0, std::move(info), std::move(key_provider)); } return key_provider; } @@ -1032,7 +1080,7 @@ std::unique_ptr<Descriptor> Parse(const std::string& descriptor, FlatSigningProv { Span<const char> sp(descriptor.data(), descriptor.size()); if (!CheckChecksum(sp, require_checksum, error)) return nullptr; - auto ret = ParseScript(sp, ParseScriptContext::TOP, out, error); + auto ret = ParseScript(0, sp, ParseScriptContext::TOP, out, error); if (sp.size() == 0 && ret) return std::unique_ptr<Descriptor>(std::move(ret)); return nullptr; } @@ -1050,3 +1098,42 @@ std::unique_ptr<Descriptor> InferDescriptor(const CScript& script, const Signing { return InferScript(script, ParseScriptContext::TOP, provider); } + +void DescriptorCache::CacheParentExtPubKey(uint32_t key_exp_pos, const CExtPubKey& xpub) +{ + m_parent_xpubs[key_exp_pos] = xpub; +} + +void DescriptorCache::CacheDerivedExtPubKey(uint32_t key_exp_pos, uint32_t der_index, const CExtPubKey& xpub) +{ + auto& xpubs = m_derived_xpubs[key_exp_pos]; + xpubs[der_index] = xpub; +} + +bool DescriptorCache::GetCachedParentExtPubKey(uint32_t key_exp_pos, CExtPubKey& xpub) const +{ + const auto& it = m_parent_xpubs.find(key_exp_pos); + if (it == m_parent_xpubs.end()) return false; + xpub = it->second; + return true; +} + +bool DescriptorCache::GetCachedDerivedExtPubKey(uint32_t key_exp_pos, uint32_t der_index, CExtPubKey& xpub) const +{ + const auto& key_exp_it = m_derived_xpubs.find(key_exp_pos); + if (key_exp_it == m_derived_xpubs.end()) return false; + const auto& der_it = key_exp_it->second.find(der_index); + if (der_it == key_exp_it->second.end()) return false; + xpub = der_it->second; + return true; +} + +const ExtPubKeyMap DescriptorCache::GetCachedParentExtPubKeys() const +{ + return m_parent_xpubs; +} + +const std::unordered_map<uint32_t, ExtPubKeyMap> DescriptorCache::GetCachedDerivedExtPubKeys() const +{ + return m_derived_xpubs; +} diff --git a/src/script/descriptor.h b/src/script/descriptor.h index 58b920c681..34cd5760de 100644 --- a/src/script/descriptor.h +++ b/src/script/descriptor.h @@ -13,6 +13,49 @@ #include <vector> +using ExtPubKeyMap = std::unordered_map<uint32_t, CExtPubKey>; + +/** Cache for single descriptor's derived extended pubkeys */ +class DescriptorCache { +private: + /** Map key expression index -> map of (key derivation index -> xpub) */ + std::unordered_map<uint32_t, ExtPubKeyMap> m_derived_xpubs; + /** Map key expression index -> parent xpub */ + ExtPubKeyMap m_parent_xpubs; + +public: + /** Cache a parent xpub + * + * @param[in] key_exp_pos Position of the key expression within the descriptor + * @param[in] xpub The CExtPubKey to cache + */ + void CacheParentExtPubKey(uint32_t key_exp_pos, const CExtPubKey& xpub); + /** Retrieve a cached parent xpub + * + * @param[in] key_exp_pos Position of the key expression within the descriptor + * @param[in] xpub The CExtPubKey to get from cache + */ + bool GetCachedParentExtPubKey(uint32_t key_exp_pos, CExtPubKey& xpub) const; + /** Cache an xpub derived at an index + * + * @param[in] key_exp_pos Position of the key expression within the descriptor + * @param[in] der_index Derivation index of the xpub + * @param[in] xpub The CExtPubKey to cache + */ + void CacheDerivedExtPubKey(uint32_t key_exp_pos, uint32_t der_index, const CExtPubKey& xpub); + /** Retrieve a cached xpub derived at an index + * + * @param[in] key_exp_pos Position of the key expression within the descriptor + * @param[in] der_index Derivation index of the xpub + * @param[in] xpub The CExtPubKey to get from cache + */ + bool GetCachedDerivedExtPubKey(uint32_t key_exp_pos, uint32_t der_index, CExtPubKey& xpub) const; + + /** Retrieve all cached parent xpubs */ + const ExtPubKeyMap GetCachedParentExtPubKeys() const; + /** Retrieve all cached derived xpubs */ + const std::unordered_map<uint32_t, ExtPubKeyMap> GetCachedDerivedExtPubKeys() const; +}; /** \brief Interface for parsed descriptor objects. * @@ -53,18 +96,18 @@ struct Descriptor { * @param[in] provider The provider to query for private keys in case of hardened derivation. * @param[out] output_scripts The expanded scriptPubKeys. * @param[out] out Scripts and public keys necessary for solving the expanded scriptPubKeys (may be equal to `provider`). - * @param[out] cache Cache data necessary to evaluate the descriptor at this point without access to private keys. + * @param[out] write_cache Cache data necessary to evaluate the descriptor at this point without access to private keys. */ - virtual bool Expand(int pos, const SigningProvider& provider, std::vector<CScript>& output_scripts, FlatSigningProvider& out, std::vector<unsigned char>* cache = nullptr) const = 0; + virtual bool Expand(int pos, const SigningProvider& provider, std::vector<CScript>& output_scripts, FlatSigningProvider& out, DescriptorCache* write_cache = nullptr) const = 0; /** Expand a descriptor at a specified position using cached expansion data. * * @param[in] pos The position at which to expand the descriptor. If IsRange() is false, this is ignored. - * @param[in] cache Cached expansion data. + * @param[in] read_cache Cached expansion data. * @param[out] output_scripts The expanded scriptPubKeys. * @param[out] out Scripts and public keys necessary for solving the expanded scriptPubKeys (may be equal to `provider`). */ - virtual bool ExpandFromCache(int pos, const std::vector<unsigned char>& cache, std::vector<CScript>& output_scripts, FlatSigningProvider& out) const = 0; + virtual bool ExpandFromCache(int pos, const DescriptorCache& read_cache, std::vector<CScript>& output_scripts, FlatSigningProvider& out) const = 0; /** Expand the private key for a descriptor at a specified position, if possible. * |