diff options
-rw-r--r-- | src/psbt.h | 41 | ||||
-rw-r--r-- | src/script/sign.h | 2 | ||||
-rw-r--r-- | src/serialize.h | 13 | ||||
-rw-r--r-- | src/test/fuzz/script_sign.cpp | 4 |
4 files changed, 38 insertions, 22 deletions
diff --git a/src/psbt.h b/src/psbt.h index 21daa050ea..a752e99e74 100644 --- a/src/psbt.h +++ b/src/psbt.h @@ -69,52 +69,52 @@ struct PSBTInput inline void Serialize(Stream& s) const { // Write the utxo if (non_witness_utxo) { - SerializeToVector(s, PSBT_IN_NON_WITNESS_UTXO); + SerializeToVector(s, CompactSizeWriter(PSBT_IN_NON_WITNESS_UTXO)); OverrideStream<Stream> os(&s, s.GetType(), s.GetVersion() | SERIALIZE_TRANSACTION_NO_WITNESS); SerializeToVector(os, non_witness_utxo); } if (!witness_utxo.IsNull()) { - SerializeToVector(s, PSBT_IN_WITNESS_UTXO); + SerializeToVector(s, CompactSizeWriter(PSBT_IN_WITNESS_UTXO)); SerializeToVector(s, witness_utxo); } if (final_script_sig.empty() && final_script_witness.IsNull()) { // Write any partial signatures for (auto sig_pair : partial_sigs) { - SerializeToVector(s, PSBT_IN_PARTIAL_SIG, Span{sig_pair.second.first}); + SerializeToVector(s, CompactSizeWriter(PSBT_IN_PARTIAL_SIG), Span{sig_pair.second.first}); s << sig_pair.second.second; } // Write the sighash type if (sighash_type != std::nullopt) { - SerializeToVector(s, PSBT_IN_SIGHASH); + SerializeToVector(s, CompactSizeWriter(PSBT_IN_SIGHASH)); SerializeToVector(s, *sighash_type); } // Write the redeem script if (!redeem_script.empty()) { - SerializeToVector(s, PSBT_IN_REDEEMSCRIPT); + SerializeToVector(s, CompactSizeWriter(PSBT_IN_REDEEMSCRIPT)); s << redeem_script; } // Write the witness script if (!witness_script.empty()) { - SerializeToVector(s, PSBT_IN_WITNESSSCRIPT); + SerializeToVector(s, CompactSizeWriter(PSBT_IN_WITNESSSCRIPT)); s << witness_script; } // Write any hd keypaths - SerializeHDKeypaths(s, hd_keypaths, PSBT_IN_BIP32_DERIVATION); + SerializeHDKeypaths(s, hd_keypaths, CompactSizeWriter(PSBT_IN_BIP32_DERIVATION)); } // Write script sig if (!final_script_sig.empty()) { - SerializeToVector(s, PSBT_IN_SCRIPTSIG); + SerializeToVector(s, CompactSizeWriter(PSBT_IN_SCRIPTSIG)); s << final_script_sig; } // write script witness if (!final_script_witness.IsNull()) { - SerializeToVector(s, PSBT_IN_SCRIPTWITNESS); + SerializeToVector(s, CompactSizeWriter(PSBT_IN_SCRIPTWITNESS)); SerializeToVector(s, final_script_witness.stack); } @@ -147,8 +147,9 @@ struct PSBTInput break; } - // First byte of key is the type - unsigned char type = key[0]; + // Type is compact size uint at beginning of key + SpanReader skey(s.GetType(), s.GetVersion(), key); + uint64_t type = ReadCompactSize(skey); // Do stuff based on type switch(type) { @@ -292,18 +293,18 @@ struct PSBTOutput inline void Serialize(Stream& s) const { // Write the redeem script if (!redeem_script.empty()) { - SerializeToVector(s, PSBT_OUT_REDEEMSCRIPT); + SerializeToVector(s, CompactSizeWriter(PSBT_OUT_REDEEMSCRIPT)); s << redeem_script; } // Write the witness script if (!witness_script.empty()) { - SerializeToVector(s, PSBT_OUT_WITNESSSCRIPT); + SerializeToVector(s, CompactSizeWriter(PSBT_OUT_WITNESSSCRIPT)); s << witness_script; } // Write any hd keypaths - SerializeHDKeypaths(s, hd_keypaths, PSBT_OUT_BIP32_DERIVATION); + SerializeHDKeypaths(s, hd_keypaths, CompactSizeWriter(PSBT_OUT_BIP32_DERIVATION)); // Write unknown things for (auto& entry : unknown) { @@ -334,8 +335,9 @@ struct PSBTOutput break; } - // First byte of key is the type - unsigned char type = key[0]; + // Type is compact size uint at beginning of key + SpanReader skey(s.GetType(), s.GetVersion(), key); + uint64_t type = ReadCompactSize(skey); // Do stuff based on type switch(type) { @@ -422,7 +424,7 @@ struct PartiallySignedTransaction s << PSBT_MAGIC_BYTES; // unsigned tx flag - SerializeToVector(s, PSBT_GLOBAL_UNSIGNED_TX); + SerializeToVector(s, CompactSizeWriter(PSBT_GLOBAL_UNSIGNED_TX)); // Write serialized tx to a stream OverrideStream<Stream> os(&s, s.GetType(), s.GetVersion() | SERIALIZE_TRANSACTION_NO_WITNESS); @@ -474,8 +476,9 @@ struct PartiallySignedTransaction break; } - // First byte of key is the type - unsigned char type = key[0]; + // Type is compact size uint at beginning of key + SpanReader skey(s.GetType(), s.GetVersion(), key); + uint64_t type = ReadCompactSize(skey); // Do stuff based on type switch(type) { diff --git a/src/script/sign.h b/src/script/sign.h index 50525af332..622147cd95 100644 --- a/src/script/sign.h +++ b/src/script/sign.h @@ -143,7 +143,7 @@ void DeserializeHDKeypaths(Stream& s, const std::vector<unsigned char>& key, std // Serialize HD keypaths to a stream from a map template<typename Stream> -void SerializeHDKeypaths(Stream& s, const std::map<CPubKey, KeyOriginInfo>& hd_keypaths, uint8_t type) +void SerializeHDKeypaths(Stream& s, const std::map<CPubKey, KeyOriginInfo>& hd_keypaths, CompactSizeWriter type) { for (auto keypath_pair : hd_keypaths) { if (!keypath_pair.first.IsValid()) { diff --git a/src/serialize.h b/src/serialize.h index edf10440c6..873361fe9e 100644 --- a/src/serialize.h +++ b/src/serialize.h @@ -527,6 +527,19 @@ struct CompactSizeFormatter } }; +class CompactSizeWriter +{ +protected: + uint64_t n; +public: + explicit CompactSizeWriter(uint64_t n_in) : n(n_in) { } + + template<typename Stream> + void Serialize(Stream &s) const { + WriteCompactSize<Stream>(s, n); + } +}; + template<size_t Limit> struct LimitedStringFormatter { diff --git a/src/test/fuzz/script_sign.cpp b/src/test/fuzz/script_sign.cpp index 79380bd9c9..205f6c8061 100644 --- a/src/test/fuzz/script_sign.cpp +++ b/src/test/fuzz/script_sign.cpp @@ -43,7 +43,7 @@ FUZZ_TARGET_INIT(script_sign, initialize_script_sign) } catch (const std::ios_base::failure&) { } CDataStream serialized{SER_NETWORK, PROTOCOL_VERSION}; - SerializeHDKeypaths(serialized, hd_keypaths, fuzzed_data_provider.ConsumeIntegral<uint8_t>()); + SerializeHDKeypaths(serialized, hd_keypaths, CompactSizeWriter(fuzzed_data_provider.ConsumeIntegral<uint8_t>())); } { @@ -61,7 +61,7 @@ FUZZ_TARGET_INIT(script_sign, initialize_script_sign) } CDataStream serialized{SER_NETWORK, PROTOCOL_VERSION}; try { - SerializeHDKeypaths(serialized, hd_keypaths, fuzzed_data_provider.ConsumeIntegral<uint8_t>()); + SerializeHDKeypaths(serialized, hd_keypaths, CompactSizeWriter(fuzzed_data_provider.ConsumeIntegral<uint8_t>())); } catch (const std::ios_base::failure&) { } std::map<CPubKey, KeyOriginInfo> deserialized_hd_keypaths; |