diff options
author | fanquake <fanquake@gmail.com> | 2023-09-07 11:27:36 +0100 |
---|---|---|
committer | fanquake <fanquake@gmail.com> | 2023-09-07 11:34:34 +0100 |
commit | 8e0d9796da8cfb6c4e918788a03eea125d0633a6 (patch) | |
tree | 4b6aee429ac616e1b419de8dc7b73f202a487a24 | |
parent | 5ad4eb33656cc6d3a56253d116b0b4e17fe9b712 (diff) | |
parent | fa626af3edbe8d98b2de91dd71729ceef90389fb (diff) |
Merge bitcoin/bitcoin#25284: net: Use serialization parameters for CAddress serialization
fa626af3edbe8d98b2de91dd71729ceef90389fb Remove unused legacy CHashVerifier (MarcoFalke)
fafa3fc5a62702da72991497e3270034eb9159c0 test: add tests that exercise WithParams() (MarcoFalke)
fac81affb527132945773a5315bd27fec61ec52f Use serialization parameters for CAddress serialization (MarcoFalke)
faec591d64e40ba7ec7656cbfdda1a05953bde13 Support for serialization parameters (MarcoFalke)
fac42e9d35f6ba046999b2e3a757ab720c51b6bb Rename CSerAction* to Action* (MarcoFalke)
aaaa3fa9477eef9ea72e4a501d130c57b47b470a Replace READWRITEAS macro with AsBase wrapping function (MarcoFalke)
Pull request description:
It seems confusing that picking a wrong value for `ADDRV2_FORMAT` could have effects on consensus. (See the docstring of `ADDRV2_FORMAT`).
Fix this by implementing https://github.com/bitcoin/bitcoin/issues/19477#issuecomment-1147421608 .
This may also help with libbitcoinkernel, see https://github.com/bitcoin/bitcoin/pull/28327
ACKs for top commit:
TheCharlatan:
ACK fa626af3edbe8d98b2de91dd71729ceef90389fb
ajtowns:
ACK fa626af3edbe8d98b2de91dd71729ceef90389fb
Tree-SHA512: 229d379da27308890de212b1fd2b85dac13f3f768413cb56a4b0c2da708f28344d04356ffd75bfcbaa4cabf0b6cc363c4f812a8f1648cff9e436811498278318
-rw-r--r-- | src/addrdb.cpp | 25 | ||||
-rw-r--r-- | src/addrdb.h | 7 | ||||
-rw-r--r-- | src/addrman.cpp | 25 | ||||
-rw-r--r-- | src/addrman_impl.h | 3 | ||||
-rw-r--r-- | src/hash.h | 39 | ||||
-rw-r--r-- | src/index/disktxpos.h | 3 | ||||
-rw-r--r-- | src/net.cpp | 3 | ||||
-rw-r--r-- | src/net_processing.cpp | 28 | ||||
-rw-r--r-- | src/netaddress.h | 25 | ||||
-rw-r--r-- | src/primitives/block.h | 3 | ||||
-rw-r--r-- | src/protocol.h | 36 | ||||
-rw-r--r-- | src/script/script.h | 2 | ||||
-rw-r--r-- | src/serialize.h | 206 | ||||
-rw-r--r-- | src/test/addrman_tests.cpp | 22 | ||||
-rw-r--r-- | src/test/fuzz/addrman.cpp | 16 | ||||
-rw-r--r-- | src/test/fuzz/deserialize.cpp | 151 | ||||
-rw-r--r-- | src/test/fuzz/net.cpp | 2 | ||||
-rw-r--r-- | src/test/fuzz/script_sign.cpp | 3 | ||||
-rw-r--r-- | src/test/fuzz/util.h | 21 | ||||
-rw-r--r-- | src/test/fuzz/util/net.cpp | 21 | ||||
-rw-r--r-- | src/test/net_tests.cpp | 81 | ||||
-rw-r--r-- | src/test/netbase_tests.cpp | 16 | ||||
-rw-r--r-- | src/test/serialize_tests.cpp | 143 | ||||
-rw-r--r-- | src/test/streams_tests.cpp | 4 | ||||
-rw-r--r-- | src/test/util/net.cpp | 4 | ||||
-rw-r--r-- | src/version.h | 2 |
26 files changed, 594 insertions, 297 deletions
diff --git a/src/addrdb.cpp b/src/addrdb.cpp index 0fcb5ed5c9..c5b474339b 100644 --- a/src/addrdb.cpp +++ b/src/addrdb.cpp @@ -47,16 +47,16 @@ bool SerializeDB(Stream& stream, const Data& data) } template <typename Data> -bool SerializeFileDB(const std::string& prefix, const fs::path& path, const Data& data, int version) +bool SerializeFileDB(const std::string& prefix, const fs::path& path, const Data& data) { // Generate random temporary filename const uint16_t randv{GetRand<uint16_t>()}; std::string tmpfn = strprintf("%s.%04x", prefix, randv); - // open temp output file, and associate with CAutoFile + // open temp output file fs::path pathTmp = gArgs.GetDataDirNet() / fs::u8path(tmpfn); FILE *file = fsbridge::fopen(pathTmp, "wb"); - CAutoFile fileout(file, SER_DISK, version); + AutoFile fileout{file}; if (fileout.IsNull()) { fileout.fclose(); remove(pathTmp); @@ -86,9 +86,9 @@ bool SerializeFileDB(const std::string& prefix, const fs::path& path, const Data } template <typename Stream, typename Data> -void DeserializeDB(Stream& stream, Data& data, bool fCheckSum = true) +void DeserializeDB(Stream& stream, Data&& data, bool fCheckSum = true) { - CHashVerifier<Stream> verifier(&stream); + HashVerifier verifier{stream}; // de-serialize file header (network specific magic number) and .. unsigned char pchMsgTmp[4]; verifier >> pchMsgTmp; @@ -111,11 +111,10 @@ void DeserializeDB(Stream& stream, Data& data, bool fCheckSum = true) } template <typename Data> -void DeserializeFileDB(const fs::path& path, Data& data, int version) +void DeserializeFileDB(const fs::path& path, Data&& data) { - // open input file, and associate with CAutoFile FILE* file = fsbridge::fopen(path, "rb"); - CAutoFile filein(file, SER_DISK, version); + AutoFile filein{file}; if (filein.IsNull()) { throw DbNotFoundError{}; } @@ -175,10 +174,10 @@ bool CBanDB::Read(banmap_t& banSet) bool DumpPeerAddresses(const ArgsManager& args, const AddrMan& addr) { const auto pathAddr = args.GetDataDirNet() / "peers.dat"; - return SerializeFileDB("peers", pathAddr, addr, CLIENT_VERSION); + return SerializeFileDB("peers", pathAddr, addr); } -void ReadFromStream(AddrMan& addr, CDataStream& ssPeers) +void ReadFromStream(AddrMan& addr, DataStream& ssPeers) { DeserializeDB(ssPeers, addr, false); } @@ -191,7 +190,7 @@ util::Result<std::unique_ptr<AddrMan>> LoadAddrman(const NetGroupManager& netgro const auto start{SteadyClock::now()}; const auto path_addr{args.GetDataDirNet() / "peers.dat"}; try { - DeserializeFileDB(path_addr, *addrman, CLIENT_VERSION); + DeserializeFileDB(path_addr, *addrman); LogPrintf("Loaded %i addresses from peers.dat %dms\n", addrman->Size(), Ticks<std::chrono::milliseconds>(SteadyClock::now() - start)); } catch (const DbNotFoundError&) { // Addrman can be in an inconsistent state after failure, reset it @@ -217,14 +216,14 @@ util::Result<std::unique_ptr<AddrMan>> LoadAddrman(const NetGroupManager& netgro void DumpAnchors(const fs::path& anchors_db_path, const std::vector<CAddress>& anchors) { LOG_TIME_SECONDS(strprintf("Flush %d outbound block-relay-only peer addresses to anchors.dat", anchors.size())); - SerializeFileDB("anchors", anchors_db_path, anchors, CLIENT_VERSION | ADDRV2_FORMAT); + SerializeFileDB("anchors", anchors_db_path, WithParams(CAddress::V2_DISK, anchors)); } std::vector<CAddress> ReadAnchors(const fs::path& anchors_db_path) { std::vector<CAddress> anchors; try { - DeserializeFileDB(anchors_db_path, anchors, CLIENT_VERSION | ADDRV2_FORMAT); + DeserializeFileDB(anchors_db_path, WithParams(CAddress::V2_DISK, anchors)); LogPrintf("Loaded %i addresses from %s\n", anchors.size(), fs::quoted(fs::PathToString(anchors_db_path.filename()))); } catch (const std::exception&) { anchors.clear(); diff --git a/src/addrdb.h b/src/addrdb.h index 0037495d18..cc3014dce2 100644 --- a/src/addrdb.h +++ b/src/addrdb.h @@ -16,12 +16,13 @@ class ArgsManager; class AddrMan; class CAddress; -class CDataStream; +class DataStream; class NetGroupManager; -bool DumpPeerAddresses(const ArgsManager& args, const AddrMan& addr); /** Only used by tests. */ -void ReadFromStream(AddrMan& addr, CDataStream& ssPeers); +void ReadFromStream(AddrMan& addr, DataStream& ssPeers); + +bool DumpPeerAddresses(const ArgsManager& args, const AddrMan& addr); /** Access to the banlist database (banlist.json) */ class CBanDB diff --git a/src/addrman.cpp b/src/addrman.cpp index 9ccf71774a..212baab9d4 100644 --- a/src/addrman.cpp +++ b/src/addrman.cpp @@ -171,8 +171,7 @@ void AddrManImpl::Serialize(Stream& s_) const */ // Always serialize in the latest version (FILE_FORMAT). - - OverrideStream<Stream> s(&s_, s_.GetType(), s_.GetVersion() | ADDRV2_FORMAT); + ParamsStream s{CAddress::V2_DISK, s_}; s << static_cast<uint8_t>(FILE_FORMAT); @@ -236,14 +235,8 @@ void AddrManImpl::Unserialize(Stream& s_) Format format; s_ >> Using<CustomUintFormatter<1>>(format); - int stream_version = s_.GetVersion(); - if (format >= Format::V3_BIP155) { - // Add ADDRV2_FORMAT to the version so that the CNetAddr and CAddress - // unserialize methods know that an address in addrv2 format is coming. - stream_version |= ADDRV2_FORMAT; - } - - OverrideStream<Stream> s(&s_, s_.GetType(), stream_version); + const auto ser_params = (format >= Format::V3_BIP155 ? CAddress::V2_DISK : CAddress::V1_DISK); + ParamsStream s{ser_params, s_}; uint8_t compat; s >> compat; @@ -1249,12 +1242,12 @@ void AddrMan::Unserialize(Stream& s_) } // explicit instantiation -template void AddrMan::Serialize(HashedSourceWriter<CAutoFile>& s) const; -template void AddrMan::Serialize(CDataStream& s) const; -template void AddrMan::Unserialize(CAutoFile& s); -template void AddrMan::Unserialize(CHashVerifier<CAutoFile>& s); -template void AddrMan::Unserialize(CDataStream& s); -template void AddrMan::Unserialize(CHashVerifier<CDataStream>& s); +template void AddrMan::Serialize(HashedSourceWriter<AutoFile>&) const; +template void AddrMan::Serialize(DataStream&) const; +template void AddrMan::Unserialize(AutoFile&); +template void AddrMan::Unserialize(HashVerifier<AutoFile>&); +template void AddrMan::Unserialize(DataStream&); +template void AddrMan::Unserialize(HashVerifier<DataStream>&); size_t AddrMan::Size(std::optional<Network> net, std::optional<bool> in_new) const { diff --git a/src/addrman_impl.h b/src/addrman_impl.h index 9aff408e34..1cfaca04a3 100644 --- a/src/addrman_impl.h +++ b/src/addrman_impl.h @@ -65,8 +65,7 @@ public: SERIALIZE_METHODS(AddrInfo, obj) { - READWRITEAS(CAddress, obj); - READWRITE(obj.source, Using<ChronoFormatter<int64_t>>(obj.m_last_success), obj.nAttempts); + READWRITE(AsBase<CAddress>(obj), obj.source, Using<ChronoFormatter<int64_t>>(obj.m_last_success), obj.nAttempts); } AddrInfo(const CAddress &addrIn, const CNetAddr &addrSource) : CAddress(addrIn), source(addrSource) diff --git a/src/hash.h b/src/hash.h index 89c6f0dab9..f2b627ff4f 100644 --- a/src/hash.h +++ b/src/hash.h @@ -199,53 +199,20 @@ public: } }; -template<typename Source> -class CHashVerifier : public CHashWriter -{ -private: - Source* source; - -public: - explicit CHashVerifier(Source* source_) : CHashWriter(source_->GetType(), source_->GetVersion()), source(source_) {} - - void read(Span<std::byte> dst) - { - source->read(dst); - this->write(dst); - } - - void ignore(size_t nSize) - { - std::byte data[1024]; - while (nSize > 0) { - size_t now = std::min<size_t>(nSize, 1024); - read({data, now}); - nSize -= now; - } - } - - template<typename T> - CHashVerifier<Source>& operator>>(T&& obj) - { - ::Unserialize(*this, obj); - return (*this); - } -}; - /** Writes data to an underlying source stream, while hashing the written data. */ template <typename Source> -class HashedSourceWriter : public CHashWriter +class HashedSourceWriter : public HashWriter { private: Source& m_source; public: - explicit HashedSourceWriter(Source& source LIFETIMEBOUND) : CHashWriter{source.GetType(), source.GetVersion()}, m_source{source} {} + explicit HashedSourceWriter(Source& source LIFETIMEBOUND) : HashWriter{}, m_source{source} {} void write(Span<const std::byte> src) { m_source.write(src); - CHashWriter::write(src); + HashWriter::write(src); } template <typename T> diff --git a/src/index/disktxpos.h b/src/index/disktxpos.h index 7718755b78..1004f7ae87 100644 --- a/src/index/disktxpos.h +++ b/src/index/disktxpos.h @@ -14,8 +14,7 @@ struct CDiskTxPos : public FlatFilePos SERIALIZE_METHODS(CDiskTxPos, obj) { - READWRITEAS(FlatFilePos, obj); - READWRITE(VARINT(obj.nTxOffset)); + READWRITE(AsBase<FlatFilePos>(obj), VARINT(obj.nTxOffset)); } CDiskTxPos(const FlatFilePos &blockIn, unsigned int nTxOffsetIn) : FlatFilePos(blockIn.nFile, blockIn.nPos), nTxOffset(nTxOffsetIn) { diff --git a/src/net.cpp b/src/net.cpp index e66c0ec7f8..4addca0982 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -202,7 +202,8 @@ static std::vector<CAddress> ConvertSeeds(const std::vector<uint8_t> &vSeedsIn) const auto one_week{7 * 24h}; std::vector<CAddress> vSeedsOut; FastRandomContext rng; - CDataStream s(vSeedsIn, SER_NETWORK, PROTOCOL_VERSION | ADDRV2_FORMAT); + DataStream underlying_stream{vSeedsIn}; + ParamsStream s{CAddress::V2_NETWORK, underlying_stream}; while (!s.eof()) { CService endpoint; s >> endpoint; diff --git a/src/net_processing.cpp b/src/net_processing.cpp index c5a22f258a..6b415b3a1e 100644 --- a/src/net_processing.cpp +++ b/src/net_processing.cpp @@ -1415,8 +1415,8 @@ void PeerManagerImpl::PushNodeVersion(CNode& pnode, const Peer& peer) const bool tx_relay{!RejectIncomingTxs(pnode)}; m_connman.PushMessage(&pnode, CNetMsgMaker(INIT_PROTO_VERSION).Make(NetMsgType::VERSION, PROTOCOL_VERSION, my_services, nTime, - your_services, addr_you, // Together the pre-version-31402 serialization of CAddress "addrYou" (without nTime) - my_services, CService(), // Together the pre-version-31402 serialization of CAddress "addrMe" (without nTime) + your_services, WithParams(CNetAddr::V1, addr_you), // Together the pre-version-31402 serialization of CAddress "addrYou" (without nTime) + my_services, WithParams(CNetAddr::V1, CService{}), // Together the pre-version-31402 serialization of CAddress "addrMe" (without nTime) nonce, strSubVersion, nNodeStartingHeight, tx_relay)); if (fLogIPs) { @@ -3293,7 +3293,7 @@ void PeerManagerImpl::ProcessMessage(CNode& pfrom, const std::string& msg_type, nTime = 0; } vRecv.ignore(8); // Ignore the addrMe service bits sent by the peer - vRecv >> addrMe; + vRecv >> WithParams(CNetAddr::V1, addrMe); if (!pfrom.IsInboundConn()) { m_addrman.SetServices(pfrom.addr, nServices); @@ -3672,17 +3672,17 @@ void PeerManagerImpl::ProcessMessage(CNode& pfrom, const std::string& msg_type, } if (msg_type == NetMsgType::ADDR || msg_type == NetMsgType::ADDRV2) { - int stream_version = vRecv.GetVersion(); - if (msg_type == NetMsgType::ADDRV2) { - // Add ADDRV2_FORMAT to the version so that the CNetAddr and CAddress + const auto ser_params{ + msg_type == NetMsgType::ADDRV2 ? + // Set V2 param so that the CNetAddr and CAddress // unserialize methods know that an address in v2 format is coming. - stream_version |= ADDRV2_FORMAT; - } + CAddress::V2_NETWORK : + CAddress::V1_NETWORK, + }; - OverrideStream<CDataStream> s(&vRecv, vRecv.GetType(), stream_version); std::vector<CAddress> vAddr; - s >> vAddr; + vRecv >> WithParams(ser_params, vAddr); if (!SetupAddressRelay(pfrom, *peer)) { LogPrint(BCLog::NET, "ignoring %s message from %s peer=%d\n", msg_type, pfrom.ConnectionTypeAsString(), pfrom.GetId()); @@ -5289,15 +5289,15 @@ void PeerManagerImpl::MaybeSendAddr(CNode& node, Peer& peer, std::chrono::micros if (peer.m_addrs_to_send.empty()) return; const char* msg_type; - int make_flags; + CNetAddr::Encoding ser_enc; if (peer.m_wants_addrv2) { msg_type = NetMsgType::ADDRV2; - make_flags = ADDRV2_FORMAT; + ser_enc = CNetAddr::Encoding::V2; } else { msg_type = NetMsgType::ADDR; - make_flags = 0; + ser_enc = CNetAddr::Encoding::V1; } - m_connman.PushMessage(&node, CNetMsgMaker(node.GetCommonVersion()).Make(make_flags, msg_type, peer.m_addrs_to_send)); + m_connman.PushMessage(&node, CNetMsgMaker(node.GetCommonVersion()).Make(msg_type, WithParams(CAddress::SerParams{{ser_enc}, CAddress::Format::Network}, peer.m_addrs_to_send))); peer.m_addrs_to_send.clear(); // we only send the big addr message once diff --git a/src/netaddress.h b/src/netaddress.h index 7cba6c00d0..a0944c886f 100644 --- a/src/netaddress.h +++ b/src/netaddress.h @@ -25,14 +25,6 @@ #include <vector> /** - * A flag that is ORed into the protocol version to designate that addresses - * should be serialized in (unserialized from) v2 format (BIP155). - * Make sure that this does not collide with any of the values in `version.h` - * or with `SERIALIZE_TRANSACTION_NO_WITNESS`. - */ -static constexpr int ADDRV2_FORMAT = 0x20000000; - -/** * A network type. * @note An address may belong to more than one network, for example `10.0.0.1` * belongs to both `NET_UNROUTABLE` and `NET_IPV4`. @@ -220,13 +212,23 @@ public: return IsIPv4() || IsIPv6() || IsTor() || IsI2P() || IsCJDNS(); } + enum class Encoding { + V1, + V2, //!< BIP155 encoding + }; + struct SerParams { + const Encoding enc; + }; + static constexpr SerParams V1{Encoding::V1}; + static constexpr SerParams V2{Encoding::V2}; + /** * Serialize to a stream. */ template <typename Stream> void Serialize(Stream& s) const { - if (s.GetVersion() & ADDRV2_FORMAT) { + if (s.GetParams().enc == Encoding::V2) { SerializeV2Stream(s); } else { SerializeV1Stream(s); @@ -239,7 +241,7 @@ public: template <typename Stream> void Unserialize(Stream& s) { - if (s.GetVersion() & ADDRV2_FORMAT) { + if (s.GetParams().enc == Encoding::V2) { UnserializeV2Stream(s); } else { UnserializeV1Stream(s); @@ -540,8 +542,7 @@ public: SERIALIZE_METHODS(CService, obj) { - READWRITEAS(CNetAddr, obj); - READWRITE(Using<BigEndianFormatter<2>>(obj.port)); + READWRITE(AsBase<CNetAddr>(obj), Using<BigEndianFormatter<2>>(obj.port)); } friend class CServiceHash; diff --git a/src/primitives/block.h b/src/primitives/block.h index bd11279a6e..861d362414 100644 --- a/src/primitives/block.h +++ b/src/primitives/block.h @@ -87,8 +87,7 @@ public: SERIALIZE_METHODS(CBlock, obj) { - READWRITEAS(CBlockHeader, obj); - READWRITE(obj.vtx); + READWRITE(AsBase<CBlockHeader>(obj), obj.vtx); } void SetNull() diff --git a/src/protocol.h b/src/protocol.h index ac4545c311..a7ca0c6f3e 100644 --- a/src/protocol.h +++ b/src/protocol.h @@ -390,35 +390,43 @@ public: CAddress(CService ipIn, ServiceFlags nServicesIn) : CService{ipIn}, nServices{nServicesIn} {}; CAddress(CService ipIn, ServiceFlags nServicesIn, NodeSeconds time) : CService{ipIn}, nTime{time}, nServices{nServicesIn} {}; - SERIALIZE_METHODS(CAddress, obj) + enum class Format { + Disk, + Network, + }; + struct SerParams : CNetAddr::SerParams { + const Format fmt; + }; + static constexpr SerParams V1_NETWORK{{CNetAddr::Encoding::V1}, Format::Network}; + static constexpr SerParams V2_NETWORK{{CNetAddr::Encoding::V2}, Format::Network}; + static constexpr SerParams V1_DISK{{CNetAddr::Encoding::V1}, Format::Disk}; + static constexpr SerParams V2_DISK{{CNetAddr::Encoding::V2}, Format::Disk}; + + SERIALIZE_METHODS_PARAMS(CAddress, obj, SerParams, params) { - // CAddress has a distinct network serialization and a disk serialization, but it should never - // be hashed (except through CHashWriter in addrdb.cpp, which sets SER_DISK), and it's - // ambiguous what that would mean. Make sure no code relying on that is introduced: - assert(!(s.GetType() & SER_GETHASH)); bool use_v2; - if (s.GetType() & SER_DISK) { + if (params.fmt == Format::Disk) { // In the disk serialization format, the encoding (v1 or v2) is determined by a flag version // that's part of the serialization itself. ADDRV2_FORMAT in the stream version only determines // whether V2 is chosen/permitted at all. uint32_t stored_format_version = DISK_VERSION_INIT; - if (s.GetVersion() & ADDRV2_FORMAT) stored_format_version |= DISK_VERSION_ADDRV2; + if (params.enc == Encoding::V2) stored_format_version |= DISK_VERSION_ADDRV2; READWRITE(stored_format_version); stored_format_version &= ~DISK_VERSION_IGNORE_MASK; // ignore low bits if (stored_format_version == 0) { use_v2 = false; - } else if (stored_format_version == DISK_VERSION_ADDRV2 && (s.GetVersion() & ADDRV2_FORMAT)) { - // Only support v2 deserialization if ADDRV2_FORMAT is set. + } else if (stored_format_version == DISK_VERSION_ADDRV2 && params.enc == Encoding::V2) { + // Only support v2 deserialization if V2 is set. use_v2 = true; } else { throw std::ios_base::failure("Unsupported CAddress disk format version"); } } else { + assert(params.fmt == Format::Network); // In the network serialization format, the encoding (v1 or v2) is determined directly by - // the value of ADDRV2_FORMAT in the stream version, as no explicitly encoded version + // the value of enc in the stream params, as no explicitly encoded version // exists in the stream. - assert(s.GetType() & SER_NETWORK); - use_v2 = s.GetVersion() & ADDRV2_FORMAT; + use_v2 = params.enc == Encoding::V2; } READWRITE(Using<LossyChronoFormatter<uint32_t>>(obj.nTime)); @@ -432,8 +440,8 @@ public: READWRITE(Using<CustomUintFormatter<8>>(obj.nServices)); } // Invoke V1/V2 serializer for CService parent object. - OverrideStream<Stream> os(&s, s.GetType(), use_v2 ? ADDRV2_FORMAT : 0); - SerReadWriteMany(os, ser_action, ReadWriteAsHelper<CService>(obj)); + const auto ser_params{use_v2 ? CNetAddr::V2 : CNetAddr::V1}; + READWRITE(WithParams(ser_params, AsBase<CService>(obj))); } //! Always included in serialization. The behavior is unspecified if the value is not representable as uint32_t. diff --git a/src/script/script.h b/src/script/script.h index 902f756afc..c329a2afd6 100644 --- a/src/script/script.h +++ b/src/script/script.h @@ -434,7 +434,7 @@ public: CScript(std::vector<unsigned char>::const_iterator pbegin, std::vector<unsigned char>::const_iterator pend) : CScriptBase(pbegin, pend) { } CScript(const unsigned char* pbegin, const unsigned char* pend) : CScriptBase(pbegin, pend) { } - SERIALIZE_METHODS(CScript, obj) { READWRITEAS(CScriptBase, obj); } + SERIALIZE_METHODS(CScript, obj) { READWRITE(AsBase<CScriptBase>(obj)); } explicit CScript(int64_t b) { operator<<(b); } explicit CScript(opcodetype b) { operator<<(b); } diff --git a/src/serialize.h b/src/serialize.h index 39f2c0f3ae..2d790190a0 100644 --- a/src/serialize.h +++ b/src/serialize.h @@ -6,6 +6,7 @@ #ifndef BITCOIN_SERIALIZE_H #define BITCOIN_SERIALIZE_H +#include <attributes.h> #include <compat/endian.h> #include <algorithm> @@ -133,12 +134,40 @@ enum SER_GETHASH = (1 << 2), }; -//! Convert the reference base type to X, without changing constness or reference type. -template<typename X> X& ReadWriteAsHelper(X& x) { return x; } -template<typename X> const X& ReadWriteAsHelper(const X& x) { return x; } +/** + * Convert any argument to a reference to X, maintaining constness. + * + * This can be used in serialization code to invoke a base class's + * serialization routines. + * + * Example use: + * class Base { ... }; + * class Child : public Base { + * int m_data; + * public: + * SERIALIZE_METHODS(Child, obj) { + * READWRITE(AsBase<Base>(obj), obj.m_data); + * } + * }; + * + * static_cast cannot easily be used here, as the type of Obj will be const Child& + * during serialization and Child& during deserialization. AsBase will convert to + * const Base& and Base& appropriately. + */ +template <class Out, class In> +Out& AsBase(In& x) +{ + static_assert(std::is_base_of_v<Out, In>); + return x; +} +template <class Out, class In> +const Out& AsBase(const In& x) +{ + static_assert(std::is_base_of_v<Out, In>); + return x; +} #define READWRITE(...) (::SerReadWriteMany(s, ser_action, __VA_ARGS__)) -#define READWRITEAS(type, obj) (::SerReadWriteMany(s, ser_action, ReadWriteAsHelper<type>(obj))) #define SER_READ(obj, code) ::SerRead(s, ser_action, obj, [&](Stream& s, typename std::remove_const<Type>::type& obj) { code; }) #define SER_WRITE(obj, code) ::SerWrite(s, ser_action, obj, [&](Stream& s, const Type& obj) { code; }) @@ -160,11 +189,66 @@ template<typename X> const X& ReadWriteAsHelper(const X& x) { return x; } */ #define FORMATTER_METHODS(cls, obj) \ template<typename Stream> \ - static void Ser(Stream& s, const cls& obj) { SerializationOps(obj, s, CSerActionSerialize()); } \ + static void Ser(Stream& s, const cls& obj) { SerializationOps(obj, s, ActionSerialize{}); } \ template<typename Stream> \ - static void Unser(Stream& s, cls& obj) { SerializationOps(obj, s, CSerActionUnserialize()); } \ + static void Unser(Stream& s, cls& obj) { SerializationOps(obj, s, ActionUnserialize{}); } \ template<typename Stream, typename Type, typename Operation> \ - static inline void SerializationOps(Type& obj, Stream& s, Operation ser_action) \ + static void SerializationOps(Type& obj, Stream& s, Operation ser_action) + +/** + * Variant of FORMATTER_METHODS that supports a declared parameter type. + * + * If a formatter has a declared parameter type, it must be invoked directly or + * indirectly with a parameter of that type. This permits making serialization + * depend on run-time context in a type-safe way. + * + * Example use: + * struct BarParameter { bool fancy; ... }; + * struct Bar { ... }; + * struct FooFormatter { + * FORMATTER_METHODS(Bar, obj, BarParameter, param) { + * if (param.fancy) { + * READWRITE(VARINT(obj.value)); + * } else { + * READWRITE(obj.value); + * } + * } + * }; + * which would then be invoked as + * READWRITE(WithParams(BarParameter{...}, Using<FooFormatter>(obj.foo))) + * + * WithParams(parameter, obj) can be invoked anywhere in the call stack; it is + * passed down recursively into all serialization code, until another + * WithParams overrides it. + * + * Parameters will be implicitly converted where appropriate. This means that + * "parent" serialization code can use a parameter that derives from, or is + * convertible to, a "child" formatter's parameter type. + * + * Compilation will fail in any context where serialization is invoked but + * no parameter of a type convertible to BarParameter is provided. + */ +#define FORMATTER_METHODS_PARAMS(cls, obj, paramcls, paramobj) \ + template <typename Stream> \ + static void Ser(Stream& s, const cls& obj) { SerializationOps(obj, s, ActionSerialize{}, s.GetParams()); } \ + template <typename Stream> \ + static void Unser(Stream& s, cls& obj) { SerializationOps(obj, s, ActionUnserialize{}, s.GetParams()); } \ + template <typename Stream, typename Type, typename Operation> \ + static void SerializationOps(Type& obj, Stream& s, Operation ser_action, const paramcls& paramobj) + +#define BASE_SERIALIZE_METHODS(cls) \ + template <typename Stream> \ + void Serialize(Stream& s) const \ + { \ + static_assert(std::is_same<const cls&, decltype(*this)>::value, "Serialize type mismatch"); \ + Ser(s, *this); \ + } \ + template <typename Stream> \ + void Unserialize(Stream& s) \ + { \ + static_assert(std::is_same<cls&, decltype(*this)>::value, "Unserialize type mismatch"); \ + Unser(s, *this); \ + } /** * Implement the Serialize and Unserialize methods by delegating to a single templated @@ -173,21 +257,19 @@ template<typename X> const X& ReadWriteAsHelper(const X& x) { return x; } * thus allows a single implementation that sees the object as const for serializing * and non-const for deserializing, without casts. */ -#define SERIALIZE_METHODS(cls, obj) \ - template<typename Stream> \ - void Serialize(Stream& s) const \ - { \ - static_assert(std::is_same<const cls&, decltype(*this)>::value, "Serialize type mismatch"); \ - Ser(s, *this); \ - } \ - template<typename Stream> \ - void Unserialize(Stream& s) \ - { \ - static_assert(std::is_same<cls&, decltype(*this)>::value, "Unserialize type mismatch"); \ - Unser(s, *this); \ - } \ +#define SERIALIZE_METHODS(cls, obj) \ + BASE_SERIALIZE_METHODS(cls) \ FORMATTER_METHODS(cls, obj) +/** + * Variant of SERIALIZE_METHODS that supports a declared parameter type. + * + * See FORMATTER_METHODS_PARAMS for more information on parameters. + */ +#define SERIALIZE_METHODS_PARAMS(cls, obj, paramcls, paramobj) \ + BASE_SERIALIZE_METHODS(cls) \ + FORMATTER_METHODS_PARAMS(cls, obj, paramcls, paramobj) + // clang-format off #ifndef CHAR_EQUALS_INT8 template <typename Stream> void Serialize(Stream&, char) = delete; // char serialization forbidden. Use uint8_t or int8_t @@ -925,26 +1007,17 @@ void Unserialize(Stream& is, std::shared_ptr<const T>& p) } - /** - * Support for SERIALIZE_METHODS and READWRITE macro. + * Support for all macros providing or using the ser_action parameter of the SerializationOps method. */ -struct CSerActionSerialize -{ +struct ActionSerialize { constexpr bool ForRead() const { return false; } }; -struct CSerActionUnserialize -{ +struct ActionUnserialize { constexpr bool ForRead() const { return true; } }; - - - - - - /* ::GetSerializeSize implementations * * Computing the serialized size of objects is done through a special stream @@ -1003,36 +1076,36 @@ inline void UnserializeMany(Stream& s, Args&&... args) } template<typename Stream, typename... Args> -inline void SerReadWriteMany(Stream& s, CSerActionSerialize ser_action, const Args&... args) +inline void SerReadWriteMany(Stream& s, ActionSerialize ser_action, const Args&... args) { ::SerializeMany(s, args...); } template<typename Stream, typename... Args> -inline void SerReadWriteMany(Stream& s, CSerActionUnserialize ser_action, Args&&... args) +inline void SerReadWriteMany(Stream& s, ActionUnserialize ser_action, Args&&... args) { ::UnserializeMany(s, args...); } template<typename Stream, typename Type, typename Fn> -inline void SerRead(Stream& s, CSerActionSerialize ser_action, Type&&, Fn&&) +inline void SerRead(Stream& s, ActionSerialize ser_action, Type&&, Fn&&) { } template<typename Stream, typename Type, typename Fn> -inline void SerRead(Stream& s, CSerActionUnserialize ser_action, Type&& obj, Fn&& fn) +inline void SerRead(Stream& s, ActionUnserialize ser_action, Type&& obj, Fn&& fn) { fn(s, std::forward<Type>(obj)); } template<typename Stream, typename Type, typename Fn> -inline void SerWrite(Stream& s, CSerActionSerialize ser_action, Type&& obj, Fn&& fn) +inline void SerWrite(Stream& s, ActionSerialize ser_action, Type&& obj, Fn&& fn) { fn(s, std::forward<Type>(obj)); } template<typename Stream, typename Type, typename Fn> -inline void SerWrite(Stream& s, CSerActionUnserialize ser_action, Type&&, Fn&&) +inline void SerWrite(Stream& s, ActionUnserialize ser_action, Type&&, Fn&&) { } @@ -1061,4 +1134,61 @@ size_t GetSerializeSizeMany(int nVersion, const T&... t) return sc.size(); } +/** Wrapper that overrides the GetParams() function of a stream (and hides GetVersion/GetType). */ +template <typename Params, typename SubStream> +class ParamsStream +{ + const Params& m_params; + SubStream& m_substream; // private to avoid leaking version/type into serialization code that shouldn't see it + +public: + ParamsStream(const Params& params LIFETIMEBOUND, SubStream& substream LIFETIMEBOUND) : m_params{params}, m_substream{substream} {} + template <typename U> ParamsStream& operator<<(const U& obj) { ::Serialize(*this, obj); return *this; } + template <typename U> ParamsStream& operator>>(U&& obj) { ::Unserialize(*this, obj); return *this; } + void write(Span<const std::byte> src) { m_substream.write(src); } + void read(Span<std::byte> dst) { m_substream.read(dst); } + void ignore(size_t num) { m_substream.ignore(num); } + bool eof() const { return m_substream.eof(); } + size_t size() const { return m_substream.size(); } + const Params& GetParams() const { return m_params; } + int GetVersion() = delete; // Deprecated with Params usage + int GetType() = delete; // Deprecated with Params usage +}; + +/** Wrapper that serializes objects with the specified parameters. */ +template <typename Params, typename T> +class ParamsWrapper +{ + static_assert(std::is_lvalue_reference<T>::value, "ParamsWrapper needs an lvalue reference type T"); + const Params& m_params; + T m_object; + +public: + explicit ParamsWrapper(const Params& params, T obj) : m_params{params}, m_object{obj} {} + + template <typename Stream> + void Serialize(Stream& s) const + { + ParamsStream ss{m_params, s}; + ::Serialize(ss, m_object); + } + template <typename Stream> + void Unserialize(Stream& s) + { + ParamsStream ss{m_params, s}; + ::Unserialize(ss, m_object); + } +}; + +/** + * Return a wrapper around t that (de)serializes it with specified parameter params. + * + * See FORMATTER_METHODS_PARAMS for more information on serialization parameters. + */ +template <typename Params, typename T> +static auto WithParams(const Params& params, T&& t) +{ + return ParamsWrapper<Params, T&>{params, t}; +} + #endif // BITCOIN_SERIALIZE_H diff --git a/src/test/addrman_tests.cpp b/src/test/addrman_tests.cpp index 329b89554d..941018a820 100644 --- a/src/test/addrman_tests.cpp +++ b/src/test/addrman_tests.cpp @@ -697,7 +697,7 @@ BOOST_AUTO_TEST_CASE(addrman_serialization) auto addrman_asmap1_dup = std::make_unique<AddrMan>(netgroupman, DETERMINISTIC, ratio); auto addrman_noasmap = std::make_unique<AddrMan>(EMPTY_NETGROUPMAN, DETERMINISTIC, ratio); - CDataStream stream(SER_NETWORK, PROTOCOL_VERSION); + DataStream stream{}; CAddress addr = CAddress(ResolveService("250.1.1.1"), NODE_NONE); CNetAddr default_source; @@ -757,7 +757,7 @@ BOOST_AUTO_TEST_CASE(remove_invalid) // Confirm that invalid addresses are ignored in unserialization. auto addrman = std::make_unique<AddrMan>(EMPTY_NETGROUPMAN, DETERMINISTIC, GetCheckRatio(m_node)); - CDataStream stream(SER_NETWORK, PROTOCOL_VERSION); + DataStream stream{}; const CAddress new1{ResolveService("5.5.5.5"), NODE_NONE}; const CAddress new2{ResolveService("6.6.6.6"), NODE_NONE}; @@ -940,9 +940,9 @@ BOOST_AUTO_TEST_CASE(addrman_evictionworks) BOOST_CHECK(!addr_pos36.tried); } -static CDataStream AddrmanToStream(const AddrMan& addrman) +static auto AddrmanToStream(const AddrMan& addrman) { - CDataStream ssPeersIn(SER_DISK, CLIENT_VERSION); + DataStream ssPeersIn{}; ssPeersIn << Params().MessageStart(); ssPeersIn << addrman; return ssPeersIn; @@ -972,7 +972,7 @@ BOOST_AUTO_TEST_CASE(load_addrman) BOOST_CHECK(addrman.Size() == 3); // Test that the de-serialization does not throw an exception. - CDataStream ssPeers1 = AddrmanToStream(addrman); + auto ssPeers1{AddrmanToStream(addrman)}; bool exceptionThrown = false; AddrMan addrman1{EMPTY_NETGROUPMAN, !DETERMINISTIC, GetCheckRatio(m_node)}; @@ -989,7 +989,7 @@ BOOST_AUTO_TEST_CASE(load_addrman) BOOST_CHECK(exceptionThrown == false); // Test that ReadFromStream creates an addrman with the correct number of addrs. - CDataStream ssPeers2 = AddrmanToStream(addrman); + DataStream ssPeers2 = AddrmanToStream(addrman); AddrMan addrman2{EMPTY_NETGROUPMAN, !DETERMINISTIC, GetCheckRatio(m_node)}; BOOST_CHECK(addrman2.Size() == 0); @@ -998,9 +998,9 @@ BOOST_AUTO_TEST_CASE(load_addrman) } // Produce a corrupt peers.dat that claims 20 addrs when it only has one addr. -static CDataStream MakeCorruptPeersDat() +static auto MakeCorruptPeersDat() { - CDataStream s(SER_DISK, CLIENT_VERSION); + DataStream s{}; s << ::Params().MessageStart(); unsigned char nVersion = 1; @@ -1019,7 +1019,7 @@ static CDataStream MakeCorruptPeersDat() std::optional<CNetAddr> resolved{LookupHost("252.2.2.2", false)}; BOOST_REQUIRE(resolved.has_value()); AddrInfo info = AddrInfo(addr, resolved.value()); - s << info; + s << WithParams(CAddress::V1_DISK, info); return s; } @@ -1027,7 +1027,7 @@ static CDataStream MakeCorruptPeersDat() BOOST_AUTO_TEST_CASE(load_addrman_corrupted) { // Test that the de-serialization of corrupted peers.dat throws an exception. - CDataStream ssPeers1 = MakeCorruptPeersDat(); + auto ssPeers1{MakeCorruptPeersDat()}; bool exceptionThrown = false; AddrMan addrman1{EMPTY_NETGROUPMAN, !DETERMINISTIC, GetCheckRatio(m_node)}; BOOST_CHECK(addrman1.Size() == 0); @@ -1041,7 +1041,7 @@ BOOST_AUTO_TEST_CASE(load_addrman_corrupted) BOOST_CHECK(exceptionThrown); // Test that ReadFromStream fails if peers.dat is corrupt - CDataStream ssPeers2 = MakeCorruptPeersDat(); + auto ssPeers2{MakeCorruptPeersDat()}; AddrMan addrman2{EMPTY_NETGROUPMAN, !DETERMINISTIC, GetCheckRatio(m_node)}; BOOST_CHECK(addrman2.Size() == 0); diff --git a/src/test/fuzz/addrman.cpp b/src/test/fuzz/addrman.cpp index 02df4590de..9611a872ec 100644 --- a/src/test/fuzz/addrman.cpp +++ b/src/test/fuzz/addrman.cpp @@ -49,7 +49,7 @@ void initialize_addrman() FUZZ_TARGET(data_stream_addr_man, .init = initialize_addrman) { FuzzedDataProvider fuzzed_data_provider{buffer.data(), buffer.size()}; - CDataStream data_stream = ConsumeDataStream(fuzzed_data_provider); + DataStream data_stream = ConsumeDataStream(fuzzed_data_provider); NetGroupManager netgroupman{ConsumeNetGroupManager(fuzzed_data_provider)}; AddrMan addr_man(netgroupman, /*deterministic=*/false, GetCheckRatio()); try { @@ -78,12 +78,12 @@ CNetAddr RandAddr(FuzzedDataProvider& fuzzed_data_provider, FastRandomContext& f net = 6; } - CDataStream s(SER_NETWORK, PROTOCOL_VERSION | ADDRV2_FORMAT); + DataStream s{}; s << net; s << fast_random_context.randbytes(net_len_map.at(net)); - s >> addr; + s >> WithParams(CAddress::V2_NETWORK, addr); } // Return a dummy IPv4 5.5.5.5 if we generated an invalid address. @@ -241,9 +241,7 @@ FUZZ_TARGET(addrman, .init = initialize_addrman) auto addr_man_ptr = std::make_unique<AddrManDeterministic>(netgroupman, fuzzed_data_provider); if (fuzzed_data_provider.ConsumeBool()) { const std::vector<uint8_t> serialized_data{ConsumeRandomLengthByteVector(fuzzed_data_provider)}; - CDataStream ds(serialized_data, SER_DISK, INIT_PROTO_VERSION); - const auto ser_version{fuzzed_data_provider.ConsumeIntegral<int32_t>()}; - ds.SetVersion(ser_version); + DataStream ds{serialized_data}; try { ds >> *addr_man_ptr; } catch (const std::ios_base::failure&) { @@ -295,7 +293,7 @@ FUZZ_TARGET(addrman, .init = initialize_addrman) in_new = fuzzed_data_provider.ConsumeBool(); } (void)const_addr_man.Size(network, in_new); - CDataStream data_stream(SER_NETWORK, PROTOCOL_VERSION); + DataStream data_stream{}; data_stream << const_addr_man; } @@ -309,10 +307,10 @@ FUZZ_TARGET(addrman_serdeser, .init = initialize_addrman) AddrManDeterministic addr_man1{netgroupman, fuzzed_data_provider}; AddrManDeterministic addr_man2{netgroupman, fuzzed_data_provider}; - CDataStream data_stream(SER_NETWORK, PROTOCOL_VERSION); + DataStream data_stream{}; FillAddrman(addr_man1, fuzzed_data_provider); data_stream << addr_man1; data_stream >> addr_man2; assert(addr_man1 == addr_man2); -}
\ No newline at end of file +} diff --git a/src/test/fuzz/deserialize.cpp b/src/test/fuzz/deserialize.cpp index 09402233bd..100a6b4ee4 100644 --- a/src/test/fuzz/deserialize.cpp +++ b/src/test/fuzz/deserialize.cpp @@ -24,6 +24,8 @@ #include <pubkey.h> #include <script/keyorigin.h> #include <streams.h> +#include <test/fuzz/fuzz.h> +#include <test/fuzz/util.h> #include <test/util/setup_common.h> #include <undo.h> #include <version.h> @@ -34,8 +36,6 @@ #include <stdint.h> #include <unistd.h> -#include <test/fuzz/fuzz.h> - using node::SnapshotMetadata; namespace { @@ -62,6 +62,34 @@ namespace { struct invalid_fuzzing_input_exception : public std::exception { }; +template <typename T, typename P> +DataStream Serialize(const T& obj, const P& params) +{ + DataStream ds{}; + ds << WithParams(params, obj); + return ds; +} + +template <typename T, typename P> +T Deserialize(DataStream&& ds, const P& params) +{ + T obj; + ds >> WithParams(params, obj); + return obj; +} + +template <typename T, typename P> +void DeserializeFromFuzzingInput(FuzzBufferType buffer, T&& obj, const P& params) +{ + DataStream ds{buffer}; + try { + ds >> WithParams(params, obj); + } catch (const std::ios_base::failure&) { + throw invalid_fuzzing_input_exception(); + } + assert(buffer.empty() || !Serialize(obj, params).empty()); +} + template <typename T> CDataStream Serialize(const T& obj, const int version = INIT_PROTO_VERSION, const int ser_type = SER_NETWORK) { @@ -79,7 +107,7 @@ T Deserialize(CDataStream ds) } template <typename T> -void DeserializeFromFuzzingInput(FuzzBufferType buffer, T& obj, const std::optional<int> protocol_version = std::nullopt, const int ser_type = SER_NETWORK) +void DeserializeFromFuzzingInput(FuzzBufferType buffer, T&& obj, const std::optional<int> protocol_version = std::nullopt, const int ser_type = SER_NETWORK) { CDataStream ds(buffer, ser_type, INIT_PROTO_VERSION); if (protocol_version) { @@ -101,6 +129,11 @@ void DeserializeFromFuzzingInput(FuzzBufferType buffer, T& obj, const std::optio assert(buffer.empty() || !Serialize(obj).empty()); } +template <typename T, typename P> +void AssertEqualAfterSerializeDeserialize(const T& obj, const P& params) +{ + assert(Deserialize<T>(Serialize(obj, params), params) == obj); +} template <typename T> void AssertEqualAfterSerializeDeserialize(const T& obj, const int version = INIT_PROTO_VERSION, const int ser_type = SER_NETWORK) { @@ -113,10 +146,11 @@ FUZZ_TARGET_DESERIALIZE(block_filter_deserialize, { BlockFilter block_filter; DeserializeFromFuzzingInput(buffer, block_filter); }) -FUZZ_TARGET_DESERIALIZE(addr_info_deserialize, { - AddrInfo addr_info; - DeserializeFromFuzzingInput(buffer, addr_info); -}) +FUZZ_TARGET(addr_info_deserialize, .init = initialize_deserialize) +{ + FuzzedDataProvider fdp{buffer.data(), buffer.size()}; + (void)ConsumeDeserializable<AddrInfo>(fdp, ConsumeDeserializationParams<CAddress::SerParams>(fdp)); +} FUZZ_TARGET_DESERIALIZE(block_file_info_deserialize, { CBlockFileInfo block_file_info; DeserializeFromFuzzingInput(buffer, block_file_info); @@ -197,13 +231,6 @@ FUZZ_TARGET_DESERIALIZE(blockmerkleroot, { bool mutated; BlockMerkleRoot(block, &mutated); }) -FUZZ_TARGET_DESERIALIZE(addrman_deserialize, { - NetGroupManager netgroupman{std::vector<bool>()}; - AddrMan am(netgroupman, - /*deterministic=*/false, - g_setup->m_node.args->GetIntArg("-checkaddrman", 0)); - DeserializeFromFuzzingInput(buffer, am); -}) FUZZ_TARGET_DESERIALIZE(blockheader_deserialize, { CBlockHeader bh; DeserializeFromFuzzingInput(buffer, bh); @@ -220,66 +247,62 @@ FUZZ_TARGET_DESERIALIZE(coins_deserialize, { Coin coin; DeserializeFromFuzzingInput(buffer, coin); }) -FUZZ_TARGET_DESERIALIZE(netaddr_deserialize, { - CNetAddr na; - DeserializeFromFuzzingInput(buffer, na); +FUZZ_TARGET(netaddr_deserialize, .init = initialize_deserialize) +{ + FuzzedDataProvider fdp{buffer.data(), buffer.size()}; + const auto maybe_na{ConsumeDeserializable<CNetAddr>(fdp, ConsumeDeserializationParams<CNetAddr::SerParams>(fdp))}; + if (!maybe_na) return; + const CNetAddr& na{*maybe_na}; if (na.IsAddrV1Compatible()) { - AssertEqualAfterSerializeDeserialize(na); + AssertEqualAfterSerializeDeserialize(na, ConsumeDeserializationParams<CNetAddr::SerParams>(fdp)); } - AssertEqualAfterSerializeDeserialize(na, INIT_PROTO_VERSION | ADDRV2_FORMAT); -}) -FUZZ_TARGET_DESERIALIZE(service_deserialize, { - CService s; - DeserializeFromFuzzingInput(buffer, s); + AssertEqualAfterSerializeDeserialize(na, CNetAddr::V2); +} +FUZZ_TARGET(service_deserialize, .init = initialize_deserialize) +{ + FuzzedDataProvider fdp{buffer.data(), buffer.size()}; + const auto ser_params{ConsumeDeserializationParams<CNetAddr::SerParams>(fdp)}; + const auto maybe_s{ConsumeDeserializable<CService>(fdp, ser_params)}; + if (!maybe_s) return; + const CService& s{*maybe_s}; if (s.IsAddrV1Compatible()) { - AssertEqualAfterSerializeDeserialize(s); + AssertEqualAfterSerializeDeserialize(s, ConsumeDeserializationParams<CNetAddr::SerParams>(fdp)); } - AssertEqualAfterSerializeDeserialize(s, INIT_PROTO_VERSION | ADDRV2_FORMAT); - CService s1; - DeserializeFromFuzzingInput(buffer, s1, INIT_PROTO_VERSION); - AssertEqualAfterSerializeDeserialize(s1, INIT_PROTO_VERSION); - assert(s1.IsAddrV1Compatible()); - CService s2; - DeserializeFromFuzzingInput(buffer, s2, INIT_PROTO_VERSION | ADDRV2_FORMAT); - AssertEqualAfterSerializeDeserialize(s2, INIT_PROTO_VERSION | ADDRV2_FORMAT); -}) + AssertEqualAfterSerializeDeserialize(s, CNetAddr::V2); + if (ser_params.enc == CNetAddr::Encoding::V1) { + assert(s.IsAddrV1Compatible()); + } +} FUZZ_TARGET_DESERIALIZE(messageheader_deserialize, { CMessageHeader mh; DeserializeFromFuzzingInput(buffer, mh); (void)mh.IsCommandValid(); }) -FUZZ_TARGET_DESERIALIZE(address_deserialize_v1_notime, { - CAddress a; - DeserializeFromFuzzingInput(buffer, a, INIT_PROTO_VERSION); - // A CAddress without nTime (as is expected under INIT_PROTO_VERSION) will roundtrip - // in all 5 formats (with/without nTime, v1/v2, network/disk) - AssertEqualAfterSerializeDeserialize(a, INIT_PROTO_VERSION); - AssertEqualAfterSerializeDeserialize(a, PROTOCOL_VERSION); - AssertEqualAfterSerializeDeserialize(a, 0, SER_DISK); - AssertEqualAfterSerializeDeserialize(a, PROTOCOL_VERSION | ADDRV2_FORMAT); - AssertEqualAfterSerializeDeserialize(a, ADDRV2_FORMAT, SER_DISK); -}) -FUZZ_TARGET_DESERIALIZE(address_deserialize_v1_withtime, { - CAddress a; - DeserializeFromFuzzingInput(buffer, a, PROTOCOL_VERSION); - // A CAddress in V1 mode will roundtrip in all 4 formats that have nTime. - AssertEqualAfterSerializeDeserialize(a, PROTOCOL_VERSION); - AssertEqualAfterSerializeDeserialize(a, 0, SER_DISK); - AssertEqualAfterSerializeDeserialize(a, PROTOCOL_VERSION | ADDRV2_FORMAT); - AssertEqualAfterSerializeDeserialize(a, ADDRV2_FORMAT, SER_DISK); -}) -FUZZ_TARGET_DESERIALIZE(address_deserialize_v2, { - CAddress a; - DeserializeFromFuzzingInput(buffer, a, PROTOCOL_VERSION | ADDRV2_FORMAT); - // A CAddress in V2 mode will roundtrip in both V2 formats, and also in the V1 formats - // with time if it's V1 compatible. - if (a.IsAddrV1Compatible()) { - AssertEqualAfterSerializeDeserialize(a, PROTOCOL_VERSION); - AssertEqualAfterSerializeDeserialize(a, 0, SER_DISK); +FUZZ_TARGET(address_deserialize, .init = initialize_deserialize) +{ + FuzzedDataProvider fdp{buffer.data(), buffer.size()}; + const auto ser_enc{ConsumeDeserializationParams<CNetAddr::SerParams>(fdp)}; + const auto maybe_a{ConsumeDeserializable<CAddress>(fdp, CAddress::SerParams{{ser_enc}, CAddress::Format::Network})}; + if (!maybe_a) return; + const CAddress& a{*maybe_a}; + // A CAddress in V1 mode will roundtrip + // in all 4 formats (v1/v2, network/disk) + if (ser_enc.enc == CNetAddr::Encoding::V1) { + AssertEqualAfterSerializeDeserialize(a, CAddress::V1_NETWORK); + AssertEqualAfterSerializeDeserialize(a, CAddress::V1_DISK); + AssertEqualAfterSerializeDeserialize(a, CAddress::V2_NETWORK); + AssertEqualAfterSerializeDeserialize(a, CAddress::V2_DISK); + } else { + // A CAddress in V2 mode will roundtrip in both V2 formats, and also in the V1 formats + // if it's V1 compatible. + if (a.IsAddrV1Compatible()) { + AssertEqualAfterSerializeDeserialize(a, CAddress::V1_DISK); + AssertEqualAfterSerializeDeserialize(a, CAddress::V1_NETWORK); + } + AssertEqualAfterSerializeDeserialize(a, CAddress::V2_NETWORK); + AssertEqualAfterSerializeDeserialize(a, CAddress::V2_DISK); } - AssertEqualAfterSerializeDeserialize(a, PROTOCOL_VERSION | ADDRV2_FORMAT); - AssertEqualAfterSerializeDeserialize(a, ADDRV2_FORMAT, SER_DISK); -}) +} FUZZ_TARGET_DESERIALIZE(inv_deserialize, { CInv i; DeserializeFromFuzzingInput(buffer, i); diff --git a/src/test/fuzz/net.cpp b/src/test/fuzz/net.cpp index ddf919f2e6..c882bd766a 100644 --- a/src/test/fuzz/net.cpp +++ b/src/test/fuzz/net.cpp @@ -53,7 +53,7 @@ FUZZ_TARGET(net, .init = initialize_net) } }, [&] { - const std::optional<CService> service_opt = ConsumeDeserializable<CService>(fuzzed_data_provider); + const std::optional<CService> service_opt = ConsumeDeserializable<CService>(fuzzed_data_provider, ConsumeDeserializationParams<CNetAddr::SerParams>(fuzzed_data_provider)); if (!service_opt) { return; } diff --git a/src/test/fuzz/script_sign.cpp b/src/test/fuzz/script_sign.cpp index b2d7d68fb4..179715b6ea 100644 --- a/src/test/fuzz/script_sign.cpp +++ b/src/test/fuzz/script_sign.cpp @@ -36,7 +36,8 @@ FUZZ_TARGET(script_sign, .init = initialize_script_sign) const std::vector<uint8_t> key = ConsumeRandomLengthByteVector(fuzzed_data_provider, 128); { - CDataStream random_data_stream = ConsumeDataStream(fuzzed_data_provider); + DataStream stream{ConsumeDataStream(fuzzed_data_provider)}; + CDataStream random_data_stream{stream, SER_NETWORK, INIT_PROTO_VERSION}; // temporary copy, to be removed along with the version flag SERIALIZE_TRANSACTION_NO_WITNESS std::map<CPubKey, KeyOriginInfo> hd_keypaths; try { DeserializeHDKeypaths(random_data_stream, key, hd_keypaths); diff --git a/src/test/fuzz/util.h b/src/test/fuzz/util.h index 4c2176f05b..8263cd4c08 100644 --- a/src/test/fuzz/util.h +++ b/src/test/fuzz/util.h @@ -71,9 +71,9 @@ template<typename B = uint8_t> return BytesToBits(ConsumeRandomLengthByteVector(fuzzed_data_provider, max_length)); } -[[nodiscard]] inline CDataStream ConsumeDataStream(FuzzedDataProvider& fuzzed_data_provider, const std::optional<size_t>& max_length = std::nullopt) noexcept +[[nodiscard]] inline DataStream ConsumeDataStream(FuzzedDataProvider& fuzzed_data_provider, const std::optional<size_t>& max_length = std::nullopt) noexcept { - return CDataStream{ConsumeRandomLengthByteVector(fuzzed_data_provider, max_length), SER_NETWORK, INIT_PROTO_VERSION}; + return DataStream{ConsumeRandomLengthByteVector(fuzzed_data_provider, max_length)}; } [[nodiscard]] inline std::vector<std::string> ConsumeRandomLengthStringVector(FuzzedDataProvider& fuzzed_data_provider, const size_t max_vector_size = 16, const size_t max_string_length = 16) noexcept @@ -97,6 +97,23 @@ template <typename T> return r; } +template <typename P> +[[nodiscard]] P ConsumeDeserializationParams(FuzzedDataProvider& fuzzed_data_provider) noexcept; + +template <typename T, typename P> +[[nodiscard]] std::optional<T> ConsumeDeserializable(FuzzedDataProvider& fuzzed_data_provider, const P& params, const std::optional<size_t>& max_length = std::nullopt) noexcept +{ + const std::vector<uint8_t> buffer{ConsumeRandomLengthByteVector(fuzzed_data_provider, max_length)}; + DataStream ds{buffer}; + T obj; + try { + ds >> WithParams(params, obj); + } catch (const std::ios_base::failure&) { + return std::nullopt; + } + return obj; +} + template <typename T> [[nodiscard]] inline std::optional<T> ConsumeDeserializable(FuzzedDataProvider& fuzzed_data_provider, const std::optional<size_t>& max_length = std::nullopt) noexcept { diff --git a/src/test/fuzz/util/net.cpp b/src/test/fuzz/util/net.cpp index 65bc336297..1545e11065 100644 --- a/src/test/fuzz/util/net.cpp +++ b/src/test/fuzz/util/net.cpp @@ -55,6 +55,27 @@ CAddress ConsumeAddress(FuzzedDataProvider& fuzzed_data_provider) noexcept return {ConsumeService(fuzzed_data_provider), ConsumeWeakEnum(fuzzed_data_provider, ALL_SERVICE_FLAGS), NodeSeconds{std::chrono::seconds{fuzzed_data_provider.ConsumeIntegral<uint32_t>()}}}; } +template <typename P> +P ConsumeDeserializationParams(FuzzedDataProvider& fuzzed_data_provider) noexcept +{ + constexpr std::array ADDR_ENCODINGS{ + CNetAddr::Encoding::V1, + CNetAddr::Encoding::V2, + }; + constexpr std::array ADDR_FORMATS{ + CAddress::Format::Disk, + CAddress::Format::Network, + }; + if constexpr (std::is_same_v<P, CNetAddr::SerParams>) { + return P{PickValue(fuzzed_data_provider, ADDR_ENCODINGS)}; + } + if constexpr (std::is_same_v<P, CAddress::SerParams>) { + return P{{PickValue(fuzzed_data_provider, ADDR_ENCODINGS)}, PickValue(fuzzed_data_provider, ADDR_FORMATS)}; + } +} +template CNetAddr::SerParams ConsumeDeserializationParams(FuzzedDataProvider&) noexcept; +template CAddress::SerParams ConsumeDeserializationParams(FuzzedDataProvider&) noexcept; + FuzzedSock::FuzzedSock(FuzzedDataProvider& fuzzed_data_provider) : m_fuzzed_data_provider{fuzzed_data_provider}, m_selectable{fuzzed_data_provider.ConsumeBool()} { diff --git a/src/test/net_tests.cpp b/src/test/net_tests.cpp index ae342a6278..295cb78b36 100644 --- a/src/test/net_tests.cpp +++ b/src/test/net_tests.cpp @@ -327,19 +327,20 @@ BOOST_AUTO_TEST_CASE(cnetaddr_tostring_canonical_ipv6) BOOST_AUTO_TEST_CASE(cnetaddr_serialize_v1) { CNetAddr addr; - CDataStream s(SER_NETWORK, PROTOCOL_VERSION); + DataStream s{}; + const auto ser_params{CAddress::V1_NETWORK}; - s << addr; + s << WithParams(ser_params, addr); BOOST_CHECK_EQUAL(HexStr(s), "00000000000000000000000000000000"); s.clear(); addr = LookupHost("1.2.3.4", false).value(); - s << addr; + s << WithParams(ser_params, addr); BOOST_CHECK_EQUAL(HexStr(s), "00000000000000000000ffff01020304"); s.clear(); addr = LookupHost("1a1b:2a2b:3a3b:4a4b:5a5b:6a6b:7a7b:8a8b", false).value(); - s << addr; + s << WithParams(ser_params, addr); BOOST_CHECK_EQUAL(HexStr(s), "1a1b2a2b3a3b4a4b5a5b6a6b7a7b8a8b"); s.clear(); @@ -347,12 +348,12 @@ BOOST_AUTO_TEST_CASE(cnetaddr_serialize_v1) BOOST_CHECK(!addr.SetSpecial("6hzph5hv6337r6p2.onion")); BOOST_REQUIRE(addr.SetSpecial("pg6mmjiyjmcrsslvykfwnntlaru7p5svn6y2ymmju6nubxndf4pscryd.onion")); - s << addr; + s << WithParams(ser_params, addr); BOOST_CHECK_EQUAL(HexStr(s), "00000000000000000000000000000000"); s.clear(); addr.SetInternal("a"); - s << addr; + s << WithParams(ser_params, addr); BOOST_CHECK_EQUAL(HexStr(s), "fd6b88c08724ca978112ca1bbdcafac2"); s.clear(); } @@ -360,22 +361,20 @@ BOOST_AUTO_TEST_CASE(cnetaddr_serialize_v1) BOOST_AUTO_TEST_CASE(cnetaddr_serialize_v2) { CNetAddr addr; - CDataStream s(SER_NETWORK, PROTOCOL_VERSION); - // Add ADDRV2_FORMAT to the version so that the CNetAddr - // serialize method produces an address in v2 format. - s.SetVersion(s.GetVersion() | ADDRV2_FORMAT); + DataStream s{}; + const auto ser_params{CAddress::V2_NETWORK}; - s << addr; + s << WithParams(ser_params, addr); BOOST_CHECK_EQUAL(HexStr(s), "021000000000000000000000000000000000"); s.clear(); addr = LookupHost("1.2.3.4", false).value(); - s << addr; + s << WithParams(ser_params, addr); BOOST_CHECK_EQUAL(HexStr(s), "010401020304"); s.clear(); addr = LookupHost("1a1b:2a2b:3a3b:4a4b:5a5b:6a6b:7a7b:8a8b", false).value(); - s << addr; + s << WithParams(ser_params, addr); BOOST_CHECK_EQUAL(HexStr(s), "02101a1b2a2b3a3b4a4b5a5b6a6b7a7b8a8b"); s.clear(); @@ -383,12 +382,12 @@ BOOST_AUTO_TEST_CASE(cnetaddr_serialize_v2) BOOST_CHECK(!addr.SetSpecial("6hzph5hv6337r6p2.onion")); BOOST_REQUIRE(addr.SetSpecial("kpgvmscirrdqpekbqjsvw5teanhatztpp2gl6eee4zkowvwfxwenqaid.onion")); - s << addr; + s << WithParams(ser_params, addr); BOOST_CHECK_EQUAL(HexStr(s), "042053cd5648488c4707914182655b7664034e09e66f7e8cbf1084e654eb56c5bd88"); s.clear(); BOOST_REQUIRE(addr.SetInternal("a")); - s << addr; + s << WithParams(ser_params, addr); BOOST_CHECK_EQUAL(HexStr(s), "0210fd6b88c08724ca978112ca1bbdcafac2"); s.clear(); } @@ -396,16 +395,14 @@ BOOST_AUTO_TEST_CASE(cnetaddr_serialize_v2) BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2) { CNetAddr addr; - CDataStream s(SER_NETWORK, PROTOCOL_VERSION); - // Add ADDRV2_FORMAT to the version so that the CNetAddr - // unserialize method expects an address in v2 format. - s.SetVersion(s.GetVersion() | ADDRV2_FORMAT); + DataStream s{}; + const auto ser_params{CAddress::V2_NETWORK}; // Valid IPv4. s << Span{ParseHex("01" // network type (IPv4) "04" // address length "01020304")}; // address - s >> addr; + s >> WithParams(ser_params, addr); BOOST_CHECK(addr.IsValid()); BOOST_CHECK(addr.IsIPv4()); BOOST_CHECK(addr.IsAddrV1Compatible()); @@ -416,7 +413,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2) s << Span{ParseHex("01" // network type (IPv4) "04" // address length "0102")}; // address - BOOST_CHECK_EXCEPTION(s >> addr, std::ios_base::failure, HasReason("end of data")); + BOOST_CHECK_EXCEPTION(s >> WithParams(ser_params, addr), std::ios_base::failure, HasReason("end of data")); BOOST_REQUIRE(!s.empty()); // The stream is not consumed on invalid input. s.clear(); @@ -424,7 +421,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2) s << Span{ParseHex("01" // network type (IPv4) "05" // address length "01020304")}; // address - BOOST_CHECK_EXCEPTION(s >> addr, std::ios_base::failure, + BOOST_CHECK_EXCEPTION(s >> WithParams(ser_params, addr), std::ios_base::failure, HasReason("BIP155 IPv4 address with length 5 (should be 4)")); BOOST_REQUIRE(!s.empty()); // The stream is not consumed on invalid input. s.clear(); @@ -433,7 +430,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2) s << Span{ParseHex("01" // network type (IPv4) "fd0102" // address length (513 as CompactSize) "01020304")}; // address - BOOST_CHECK_EXCEPTION(s >> addr, std::ios_base::failure, + BOOST_CHECK_EXCEPTION(s >> WithParams(ser_params, addr), std::ios_base::failure, HasReason("Address too long: 513 > 512")); BOOST_REQUIRE(!s.empty()); // The stream is not consumed on invalid input. s.clear(); @@ -442,7 +439,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2) s << Span{ParseHex("02" // network type (IPv6) "10" // address length "0102030405060708090a0b0c0d0e0f10")}; // address - s >> addr; + s >> WithParams(ser_params, addr); BOOST_CHECK(addr.IsValid()); BOOST_CHECK(addr.IsIPv6()); BOOST_CHECK(addr.IsAddrV1Compatible()); @@ -455,7 +452,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2) "10" // address length "fd6b88c08724ca978112ca1bbdcafac2")}; // address: 0xfd + sha256("bitcoin")[0:5] + // sha256(name)[0:10] - s >> addr; + s >> WithParams(ser_params, addr); BOOST_CHECK(addr.IsInternal()); BOOST_CHECK(addr.IsAddrV1Compatible()); BOOST_CHECK_EQUAL(addr.ToStringAddr(), "zklycewkdo64v6wc.internal"); @@ -465,7 +462,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2) s << Span{ParseHex("02" // network type (IPv6) "04" // address length "00")}; // address - BOOST_CHECK_EXCEPTION(s >> addr, std::ios_base::failure, + BOOST_CHECK_EXCEPTION(s >> WithParams(ser_params, addr), std::ios_base::failure, HasReason("BIP155 IPv6 address with length 4 (should be 16)")); BOOST_REQUIRE(!s.empty()); // The stream is not consumed on invalid input. s.clear(); @@ -474,7 +471,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2) s << Span{ParseHex("02" // network type (IPv6) "10" // address length "00000000000000000000ffff01020304")}; // address - s >> addr; + s >> WithParams(ser_params, addr); BOOST_CHECK(!addr.IsValid()); BOOST_REQUIRE(s.empty()); @@ -482,7 +479,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2) s << Span{ParseHex("02" // network type (IPv6) "10" // address length "fd87d87eeb430102030405060708090a")}; // address - s >> addr; + s >> WithParams(ser_params, addr); BOOST_CHECK(!addr.IsValid()); BOOST_REQUIRE(s.empty()); @@ -490,7 +487,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2) s << Span{ParseHex("03" // network type (TORv2) "0a" // address length "f1f2f3f4f5f6f7f8f9fa")}; // address - s >> addr; + s >> WithParams(ser_params, addr); BOOST_CHECK(!addr.IsValid()); BOOST_REQUIRE(s.empty()); @@ -500,7 +497,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2) "79bcc625184b05194975c28b66b66b04" // address "69f7f6556fb1ac3189a79b40dda32f1f" )}; - s >> addr; + s >> WithParams(ser_params, addr); BOOST_CHECK(addr.IsValid()); BOOST_CHECK(addr.IsTor()); BOOST_CHECK(!addr.IsAddrV1Compatible()); @@ -513,7 +510,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2) "00" // address length "00" // address )}; - BOOST_CHECK_EXCEPTION(s >> addr, std::ios_base::failure, + BOOST_CHECK_EXCEPTION(s >> WithParams(ser_params, addr), std::ios_base::failure, HasReason("BIP155 TORv3 address with length 0 (should be 32)")); BOOST_REQUIRE(!s.empty()); // The stream is not consumed on invalid input. s.clear(); @@ -523,7 +520,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2) "20" // address length "a2894dabaec08c0051a481a6dac88b64" // address "f98232ae42d4b6fd2fa81952dfe36a87")}; - s >> addr; + s >> WithParams(ser_params, addr); BOOST_CHECK(addr.IsValid()); BOOST_CHECK(addr.IsI2P()); BOOST_CHECK(!addr.IsAddrV1Compatible()); @@ -536,7 +533,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2) "03" // address length "00" // address )}; - BOOST_CHECK_EXCEPTION(s >> addr, std::ios_base::failure, + BOOST_CHECK_EXCEPTION(s >> WithParams(ser_params, addr), std::ios_base::failure, HasReason("BIP155 I2P address with length 3 (should be 32)")); BOOST_REQUIRE(!s.empty()); // The stream is not consumed on invalid input. s.clear(); @@ -546,7 +543,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2) "10" // address length "fc000001000200030004000500060007" // address )}; - s >> addr; + s >> WithParams(ser_params, addr); BOOST_CHECK(addr.IsValid()); BOOST_CHECK(addr.IsCJDNS()); BOOST_CHECK(!addr.IsAddrV1Compatible()); @@ -558,7 +555,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2) "10" // address length "aa000001000200030004000500060007" // address )}; - s >> addr; + s >> WithParams(ser_params, addr); BOOST_CHECK(addr.IsCJDNS()); BOOST_CHECK(!addr.IsValid()); BOOST_REQUIRE(s.empty()); @@ -568,7 +565,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2) "01" // address length "00" // address )}; - BOOST_CHECK_EXCEPTION(s >> addr, std::ios_base::failure, + BOOST_CHECK_EXCEPTION(s >> WithParams(ser_params, addr), std::ios_base::failure, HasReason("BIP155 CJDNS address with length 1 (should be 16)")); BOOST_REQUIRE(!s.empty()); // The stream is not consumed on invalid input. s.clear(); @@ -578,7 +575,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2) "fe00000002" // address length (CompactSize's MAX_SIZE) "01020304050607" // address )}; - BOOST_CHECK_EXCEPTION(s >> addr, std::ios_base::failure, + BOOST_CHECK_EXCEPTION(s >> WithParams(ser_params, addr), std::ios_base::failure, HasReason("Address too long: 33554432 > 512")); BOOST_REQUIRE(!s.empty()); // The stream is not consumed on invalid input. s.clear(); @@ -588,7 +585,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2) "04" // address length "01020304" // address )}; - s >> addr; + s >> WithParams(ser_params, addr); BOOST_CHECK(!addr.IsValid()); BOOST_REQUIRE(s.empty()); @@ -597,7 +594,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2) "00" // address length "" // address )}; - s >> addr; + s >> WithParams(ser_params, addr); BOOST_CHECK(!addr.IsValid()); BOOST_REQUIRE(s.empty()); } @@ -852,7 +849,7 @@ BOOST_AUTO_TEST_CASE(initial_advertise_from_version_message) std::chrono::microseconds time_received_dummy{0}; const auto msg_version = - msg_maker.Make(NetMsgType::VERSION, PROTOCOL_VERSION, services, time, services, peer_us); + msg_maker.Make(NetMsgType::VERSION, PROTOCOL_VERSION, services, time, services, WithParams(CAddress::V1_NETWORK, peer_us)); CDataStream msg_version_stream{msg_version.data, SER_NETWORK, PROTOCOL_VERSION}; m_node.peerman->ProcessMessage( @@ -875,10 +872,10 @@ BOOST_AUTO_TEST_CASE(initial_advertise_from_version_message) Span<const unsigned char> data, bool is_incoming) -> void { if (!is_incoming && msg_type == "addr") { - CDataStream s(data, SER_NETWORK, PROTOCOL_VERSION); + DataStream s{data}; std::vector<CAddress> addresses; - s >> addresses; + s >> WithParams(CAddress::V1_NETWORK, addresses); for (const auto& addr : addresses) { if (addr == expected) { diff --git a/src/test/netbase_tests.cpp b/src/test/netbase_tests.cpp index 05953bfd10..e22bf7e7c0 100644 --- a/src/test/netbase_tests.cpp +++ b/src/test/netbase_tests.cpp @@ -559,35 +559,35 @@ static constexpr const char* stream_addrv2_hex = BOOST_AUTO_TEST_CASE(caddress_serialize_v1) { - CDataStream s(SER_NETWORK, PROTOCOL_VERSION); + DataStream s{}; - s << fixture_addresses; + s << WithParams(CAddress::V1_NETWORK, fixture_addresses); BOOST_CHECK_EQUAL(HexStr(s), stream_addrv1_hex); } BOOST_AUTO_TEST_CASE(caddress_unserialize_v1) { - CDataStream s(ParseHex(stream_addrv1_hex), SER_NETWORK, PROTOCOL_VERSION); + DataStream s{ParseHex(stream_addrv1_hex)}; std::vector<CAddress> addresses_unserialized; - s >> addresses_unserialized; + s >> WithParams(CAddress::V1_NETWORK, addresses_unserialized); BOOST_CHECK(fixture_addresses == addresses_unserialized); } BOOST_AUTO_TEST_CASE(caddress_serialize_v2) { - CDataStream s(SER_NETWORK, PROTOCOL_VERSION | ADDRV2_FORMAT); + DataStream s{}; - s << fixture_addresses; + s << WithParams(CAddress::V2_NETWORK, fixture_addresses); BOOST_CHECK_EQUAL(HexStr(s), stream_addrv2_hex); } BOOST_AUTO_TEST_CASE(caddress_unserialize_v2) { - CDataStream s(ParseHex(stream_addrv2_hex), SER_NETWORK, PROTOCOL_VERSION | ADDRV2_FORMAT); + DataStream s{ParseHex(stream_addrv2_hex)}; std::vector<CAddress> addresses_unserialized; - s >> addresses_unserialized; + s >> WithParams(CAddress::V2_NETWORK, addresses_unserialized); BOOST_CHECK(fixture_addresses == addresses_unserialized); } diff --git a/src/test/serialize_tests.cpp b/src/test/serialize_tests.cpp index 2e862621bc..2f2bb6698c 100644 --- a/src/test/serialize_tests.cpp +++ b/src/test/serialize_tests.cpp @@ -9,6 +9,7 @@ #include <util/strencodings.h> #include <stdint.h> +#include <string> #include <boost/test/unit_test.hpp> @@ -254,4 +255,146 @@ BOOST_AUTO_TEST_CASE(class_methods) } } +enum class BaseFormat { + RAW, + HEX, +}; + +/// (Un)serialize a number as raw byte or 2 hexadecimal chars. +class Base +{ +public: + uint8_t m_base_data; + + Base() : m_base_data(17) {} + explicit Base(uint8_t data) : m_base_data(data) {} + + template <typename Stream> + void Serialize(Stream& s) const + { + if (s.GetParams() == BaseFormat::RAW) { + s << m_base_data; + } else { + s << Span{HexStr(Span{&m_base_data, 1})}; + } + } + + template <typename Stream> + void Unserialize(Stream& s) + { + if (s.GetParams() == BaseFormat::RAW) { + s >> m_base_data; + } else { + std::string hex{"aa"}; + s >> Span{hex}.first(hex.size()); + m_base_data = TryParseHex<uint8_t>(hex).value().at(0); + } + } +}; + +class DerivedAndBaseFormat +{ +public: + BaseFormat m_base_format; + + enum class DerivedFormat { + LOWER, + UPPER, + } m_derived_format; +}; + +class Derived : public Base +{ +public: + std::string m_derived_data; + + SERIALIZE_METHODS_PARAMS(Derived, obj, DerivedAndBaseFormat, fmt) + { + READWRITE(WithParams(fmt.m_base_format, AsBase<Base>(obj))); + + if (ser_action.ForRead()) { + std::string str; + s >> str; + SER_READ(obj, obj.m_derived_data = str); + } else { + s << (fmt.m_derived_format == DerivedAndBaseFormat::DerivedFormat::LOWER ? + ToLower(obj.m_derived_data) : + ToUpper(obj.m_derived_data)); + } + } +}; + +BOOST_AUTO_TEST_CASE(with_params_base) +{ + Base b{0x0F}; + + DataStream stream; + + stream << WithParams(BaseFormat::RAW, b); + BOOST_CHECK_EQUAL(stream.str(), "\x0F"); + + b.m_base_data = 0; + stream >> WithParams(BaseFormat::RAW, b); + BOOST_CHECK_EQUAL(b.m_base_data, 0x0F); + + stream.clear(); + + stream << WithParams(BaseFormat::HEX, b); + BOOST_CHECK_EQUAL(stream.str(), "0f"); + + b.m_base_data = 0; + stream >> WithParams(BaseFormat::HEX, b); + BOOST_CHECK_EQUAL(b.m_base_data, 0x0F); +} + +BOOST_AUTO_TEST_CASE(with_params_vector_of_base) +{ + std::vector<Base> v{Base{0x0F}, Base{0xFF}}; + + DataStream stream; + + stream << WithParams(BaseFormat::RAW, v); + BOOST_CHECK_EQUAL(stream.str(), "\x02\x0F\xFF"); + + v[0].m_base_data = 0; + v[1].m_base_data = 0; + stream >> WithParams(BaseFormat::RAW, v); + BOOST_CHECK_EQUAL(v[0].m_base_data, 0x0F); + BOOST_CHECK_EQUAL(v[1].m_base_data, 0xFF); + + stream.clear(); + + stream << WithParams(BaseFormat::HEX, v); + BOOST_CHECK_EQUAL(stream.str(), "\x02" + "0fff"); + + v[0].m_base_data = 0; + v[1].m_base_data = 0; + stream >> WithParams(BaseFormat::HEX, v); + BOOST_CHECK_EQUAL(v[0].m_base_data, 0x0F); + BOOST_CHECK_EQUAL(v[1].m_base_data, 0xFF); +} + +BOOST_AUTO_TEST_CASE(with_params_derived) +{ + Derived d; + d.m_base_data = 0x0F; + d.m_derived_data = "xY"; + + DerivedAndBaseFormat fmt; + + DataStream stream; + + fmt.m_base_format = BaseFormat::RAW; + fmt.m_derived_format = DerivedAndBaseFormat::DerivedFormat::LOWER; + stream << WithParams(fmt, d); + + fmt.m_base_format = BaseFormat::HEX; + fmt.m_derived_format = DerivedAndBaseFormat::DerivedFormat::UPPER; + stream << WithParams(fmt, d); + + BOOST_CHECK_EQUAL(stream.str(), "\x0F\x02xy" + "0f\x02XY"); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/src/test/streams_tests.cpp b/src/test/streams_tests.cpp index 5232175824..589a2fd766 100644 --- a/src/test/streams_tests.cpp +++ b/src/test/streams_tests.cpp @@ -553,12 +553,12 @@ BOOST_AUTO_TEST_CASE(streams_buffered_file_rand) BOOST_AUTO_TEST_CASE(streams_hashed) { - CDataStream stream(SER_NETWORK, INIT_PROTO_VERSION); + DataStream stream{}; HashedSourceWriter hash_writer{stream}; const std::string data{"bitcoin"}; hash_writer << data; - CHashVerifier hash_verifier{&stream}; + HashVerifier hash_verifier{stream}; std::string result; hash_verifier >> result; BOOST_CHECK_EQUAL(data, result); diff --git a/src/test/util/net.cpp b/src/test/util/net.cpp index 8015db3e80..5696f8d13c 100644 --- a/src/test/util/net.cpp +++ b/src/test/util/net.cpp @@ -33,9 +33,9 @@ void ConnmanTestMsg::Handshake(CNode& node, Using<CustomUintFormatter<8>>(remote_services), // int64_t{}, // dummy time int64_t{}, // ignored service bits - CService{}, // dummy + WithParams(CNetAddr::V1, CService{}), // dummy int64_t{}, // ignored service bits - CService{}, // ignored + WithParams(CNetAddr::V1, CService{}), // ignored uint64_t{1}, // dummy nonce std::string{}, // dummy subver int32_t{}, // dummy starting_height diff --git a/src/version.h b/src/version.h index 611a670314..204df3758b 100644 --- a/src/version.h +++ b/src/version.h @@ -36,6 +36,6 @@ static const int INVALID_CB_NO_BAN_VERSION = 70015; static const int WTXID_RELAY_VERSION = 70016; // Make sure that none of the values above collide with -// `SERIALIZE_TRANSACTION_NO_WITNESS` or `ADDRV2_FORMAT`. +// `SERIALIZE_TRANSACTION_NO_WITNESS`. #endif // BITCOIN_VERSION_H |