aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/addrdb.cpp25
-rw-r--r--src/addrdb.h7
-rw-r--r--src/addrman.cpp25
-rw-r--r--src/addrman_impl.h3
-rw-r--r--src/chain.h12
-rw-r--r--src/dbwrapper.h5
-rw-r--r--src/hash.h39
-rw-r--r--src/index/blockfilterindex.cpp1
-rw-r--r--src/index/disktxpos.h3
-rw-r--r--src/index/txindex.cpp1
-rw-r--r--src/init.cpp1
-rw-r--r--src/net.cpp708
-rw-r--r--src/net.h267
-rw-r--r--src/net_processing.cpp28
-rw-r--r--src/netaddress.h25
-rw-r--r--src/primitives/block.h17
-rw-r--r--src/protocol.h36
-rw-r--r--src/pubkey.cpp6
-rw-r--r--src/pubkey.h3
-rw-r--r--src/script/script.h2
-rw-r--r--src/serialize.h206
-rw-r--r--src/test/addrman_tests.cpp22
-rw-r--r--src/test/bip324_tests.cpp10
-rw-r--r--src/test/blockmanager_tests.cpp1
-rw-r--r--src/test/denialofservice_tests.cpp4
-rw-r--r--src/test/fuzz/addrman.cpp16
-rw-r--r--src/test/fuzz/bip324.cpp10
-rw-r--r--src/test/fuzz/deserialize.cpp151
-rw-r--r--src/test/fuzz/key.cpp16
-rw-r--r--src/test/fuzz/message.cpp4
-rw-r--r--src/test/fuzz/net.cpp2
-rw-r--r--src/test/fuzz/p2p_transport_serialization.cpp91
-rw-r--r--src/test/fuzz/rpc.cpp8
-rw-r--r--src/test/fuzz/script_sign.cpp7
-rw-r--r--src/test/fuzz/util.cpp33
-rw-r--r--src/test/fuzz/util.h24
-rw-r--r--src/test/fuzz/util/net.cpp21
-rw-r--r--src/test/net_tests.cpp608
-rw-r--r--src/test/netbase_tests.cpp16
-rw-r--r--src/test/serialize_tests.cpp143
-rw-r--r--src/test/streams_tests.cpp4
-rw-r--r--src/test/util/net.cpp8
-rw-r--r--src/validation.cpp1
-rw-r--r--src/version.h2
44 files changed, 2219 insertions, 403 deletions
diff --git a/src/addrdb.cpp b/src/addrdb.cpp
index 0fcb5ed5c9..c5b474339b 100644
--- a/src/addrdb.cpp
+++ b/src/addrdb.cpp
@@ -47,16 +47,16 @@ bool SerializeDB(Stream& stream, const Data& data)
}
template <typename Data>
-bool SerializeFileDB(const std::string& prefix, const fs::path& path, const Data& data, int version)
+bool SerializeFileDB(const std::string& prefix, const fs::path& path, const Data& data)
{
// Generate random temporary filename
const uint16_t randv{GetRand<uint16_t>()};
std::string tmpfn = strprintf("%s.%04x", prefix, randv);
- // open temp output file, and associate with CAutoFile
+ // open temp output file
fs::path pathTmp = gArgs.GetDataDirNet() / fs::u8path(tmpfn);
FILE *file = fsbridge::fopen(pathTmp, "wb");
- CAutoFile fileout(file, SER_DISK, version);
+ AutoFile fileout{file};
if (fileout.IsNull()) {
fileout.fclose();
remove(pathTmp);
@@ -86,9 +86,9 @@ bool SerializeFileDB(const std::string& prefix, const fs::path& path, const Data
}
template <typename Stream, typename Data>
-void DeserializeDB(Stream& stream, Data& data, bool fCheckSum = true)
+void DeserializeDB(Stream& stream, Data&& data, bool fCheckSum = true)
{
- CHashVerifier<Stream> verifier(&stream);
+ HashVerifier verifier{stream};
// de-serialize file header (network specific magic number) and ..
unsigned char pchMsgTmp[4];
verifier >> pchMsgTmp;
@@ -111,11 +111,10 @@ void DeserializeDB(Stream& stream, Data& data, bool fCheckSum = true)
}
template <typename Data>
-void DeserializeFileDB(const fs::path& path, Data& data, int version)
+void DeserializeFileDB(const fs::path& path, Data&& data)
{
- // open input file, and associate with CAutoFile
FILE* file = fsbridge::fopen(path, "rb");
- CAutoFile filein(file, SER_DISK, version);
+ AutoFile filein{file};
if (filein.IsNull()) {
throw DbNotFoundError{};
}
@@ -175,10 +174,10 @@ bool CBanDB::Read(banmap_t& banSet)
bool DumpPeerAddresses(const ArgsManager& args, const AddrMan& addr)
{
const auto pathAddr = args.GetDataDirNet() / "peers.dat";
- return SerializeFileDB("peers", pathAddr, addr, CLIENT_VERSION);
+ return SerializeFileDB("peers", pathAddr, addr);
}
-void ReadFromStream(AddrMan& addr, CDataStream& ssPeers)
+void ReadFromStream(AddrMan& addr, DataStream& ssPeers)
{
DeserializeDB(ssPeers, addr, false);
}
@@ -191,7 +190,7 @@ util::Result<std::unique_ptr<AddrMan>> LoadAddrman(const NetGroupManager& netgro
const auto start{SteadyClock::now()};
const auto path_addr{args.GetDataDirNet() / "peers.dat"};
try {
- DeserializeFileDB(path_addr, *addrman, CLIENT_VERSION);
+ DeserializeFileDB(path_addr, *addrman);
LogPrintf("Loaded %i addresses from peers.dat %dms\n", addrman->Size(), Ticks<std::chrono::milliseconds>(SteadyClock::now() - start));
} catch (const DbNotFoundError&) {
// Addrman can be in an inconsistent state after failure, reset it
@@ -217,14 +216,14 @@ util::Result<std::unique_ptr<AddrMan>> LoadAddrman(const NetGroupManager& netgro
void DumpAnchors(const fs::path& anchors_db_path, const std::vector<CAddress>& anchors)
{
LOG_TIME_SECONDS(strprintf("Flush %d outbound block-relay-only peer addresses to anchors.dat", anchors.size()));
- SerializeFileDB("anchors", anchors_db_path, anchors, CLIENT_VERSION | ADDRV2_FORMAT);
+ SerializeFileDB("anchors", anchors_db_path, WithParams(CAddress::V2_DISK, anchors));
}
std::vector<CAddress> ReadAnchors(const fs::path& anchors_db_path)
{
std::vector<CAddress> anchors;
try {
- DeserializeFileDB(anchors_db_path, anchors, CLIENT_VERSION | ADDRV2_FORMAT);
+ DeserializeFileDB(anchors_db_path, WithParams(CAddress::V2_DISK, anchors));
LogPrintf("Loaded %i addresses from %s\n", anchors.size(), fs::quoted(fs::PathToString(anchors_db_path.filename())));
} catch (const std::exception&) {
anchors.clear();
diff --git a/src/addrdb.h b/src/addrdb.h
index 0037495d18..cc3014dce2 100644
--- a/src/addrdb.h
+++ b/src/addrdb.h
@@ -16,12 +16,13 @@
class ArgsManager;
class AddrMan;
class CAddress;
-class CDataStream;
+class DataStream;
class NetGroupManager;
-bool DumpPeerAddresses(const ArgsManager& args, const AddrMan& addr);
/** Only used by tests. */
-void ReadFromStream(AddrMan& addr, CDataStream& ssPeers);
+void ReadFromStream(AddrMan& addr, DataStream& ssPeers);
+
+bool DumpPeerAddresses(const ArgsManager& args, const AddrMan& addr);
/** Access to the banlist database (banlist.json) */
class CBanDB
diff --git a/src/addrman.cpp b/src/addrman.cpp
index 9ccf71774a..212baab9d4 100644
--- a/src/addrman.cpp
+++ b/src/addrman.cpp
@@ -171,8 +171,7 @@ void AddrManImpl::Serialize(Stream& s_) const
*/
// Always serialize in the latest version (FILE_FORMAT).
-
- OverrideStream<Stream> s(&s_, s_.GetType(), s_.GetVersion() | ADDRV2_FORMAT);
+ ParamsStream s{CAddress::V2_DISK, s_};
s << static_cast<uint8_t>(FILE_FORMAT);
@@ -236,14 +235,8 @@ void AddrManImpl::Unserialize(Stream& s_)
Format format;
s_ >> Using<CustomUintFormatter<1>>(format);
- int stream_version = s_.GetVersion();
- if (format >= Format::V3_BIP155) {
- // Add ADDRV2_FORMAT to the version so that the CNetAddr and CAddress
- // unserialize methods know that an address in addrv2 format is coming.
- stream_version |= ADDRV2_FORMAT;
- }
-
- OverrideStream<Stream> s(&s_, s_.GetType(), stream_version);
+ const auto ser_params = (format >= Format::V3_BIP155 ? CAddress::V2_DISK : CAddress::V1_DISK);
+ ParamsStream s{ser_params, s_};
uint8_t compat;
s >> compat;
@@ -1249,12 +1242,12 @@ void AddrMan::Unserialize(Stream& s_)
}
// explicit instantiation
-template void AddrMan::Serialize(HashedSourceWriter<CAutoFile>& s) const;
-template void AddrMan::Serialize(CDataStream& s) const;
-template void AddrMan::Unserialize(CAutoFile& s);
-template void AddrMan::Unserialize(CHashVerifier<CAutoFile>& s);
-template void AddrMan::Unserialize(CDataStream& s);
-template void AddrMan::Unserialize(CHashVerifier<CDataStream>& s);
+template void AddrMan::Serialize(HashedSourceWriter<AutoFile>&) const;
+template void AddrMan::Serialize(DataStream&) const;
+template void AddrMan::Unserialize(AutoFile&);
+template void AddrMan::Unserialize(HashVerifier<AutoFile>&);
+template void AddrMan::Unserialize(DataStream&);
+template void AddrMan::Unserialize(HashVerifier<DataStream>&);
size_t AddrMan::Size(std::optional<Network> net, std::optional<bool> in_new) const
{
diff --git a/src/addrman_impl.h b/src/addrman_impl.h
index 9aff408e34..1cfaca04a3 100644
--- a/src/addrman_impl.h
+++ b/src/addrman_impl.h
@@ -65,8 +65,7 @@ public:
SERIALIZE_METHODS(AddrInfo, obj)
{
- READWRITEAS(CAddress, obj);
- READWRITE(obj.source, Using<ChronoFormatter<int64_t>>(obj.m_last_success), obj.nAttempts);
+ READWRITE(AsBase<CAddress>(obj), obj.source, Using<ChronoFormatter<int64_t>>(obj.m_last_success), obj.nAttempts);
}
AddrInfo(const CAddress &addrIn, const CNetAddr &addrSource) : CAddress(addrIn), source(addrSource)
diff --git a/src/chain.h b/src/chain.h
index 2e1fb37bec..7806720ce9 100644
--- a/src/chain.h
+++ b/src/chain.h
@@ -388,6 +388,14 @@ const CBlockIndex* LastCommonAncestor(const CBlockIndex* pa, const CBlockIndex*
/** Used to marshal pointers into hashes for db storage. */
class CDiskBlockIndex : public CBlockIndex
{
+ /** Historically CBlockLocator's version field has been written to disk
+ * streams as the client version, but the value has never been used.
+ *
+ * Hard-code to the highest client version ever written.
+ * SerParams can be used if the field requires any meaning in the future.
+ **/
+ static constexpr int DUMMY_VERSION = 259900;
+
public:
uint256 hashPrev;
@@ -404,8 +412,8 @@ public:
SERIALIZE_METHODS(CDiskBlockIndex, obj)
{
LOCK(::cs_main);
- int _nVersion = s.GetVersion();
- if (!(s.GetType() & SER_GETHASH)) READWRITE(VARINT_MODE(_nVersion, VarIntMode::NONNEGATIVE_SIGNED));
+ int _nVersion = DUMMY_VERSION;
+ READWRITE(VARINT_MODE(_nVersion, VarIntMode::NONNEGATIVE_SIGNED));
READWRITE(VARINT_MODE(obj.nHeight, VarIntMode::NONNEGATIVE_SIGNED));
READWRITE(VARINT(obj.nStatus));
diff --git a/src/dbwrapper.h b/src/dbwrapper.h
index eac9594aa1..2f7448e878 100644
--- a/src/dbwrapper.h
+++ b/src/dbwrapper.h
@@ -6,7 +6,6 @@
#define BITCOIN_DBWRAPPER_H
#include <attributes.h>
-#include <clientversion.h>
#include <serialize.h>
#include <span.h>
#include <streams.h>
@@ -167,7 +166,7 @@ public:
template<typename V> bool GetValue(V& value) {
try {
- CDataStream ssValue{GetValueImpl(), SER_DISK, CLIENT_VERSION};
+ DataStream ssValue{GetValueImpl()};
ssValue.Xor(dbwrapper_private::GetObfuscateKey(parent));
ssValue >> value;
} catch (const std::exception&) {
@@ -229,7 +228,7 @@ public:
return false;
}
try {
- CDataStream ssValue{MakeByteSpan(*strValue), SER_DISK, CLIENT_VERSION};
+ DataStream ssValue{MakeByteSpan(*strValue)};
ssValue.Xor(obfuscate_key);
ssValue >> value;
} catch (const std::exception&) {
diff --git a/src/hash.h b/src/hash.h
index 89c6f0dab9..f2b627ff4f 100644
--- a/src/hash.h
+++ b/src/hash.h
@@ -199,53 +199,20 @@ public:
}
};
-template<typename Source>
-class CHashVerifier : public CHashWriter
-{
-private:
- Source* source;
-
-public:
- explicit CHashVerifier(Source* source_) : CHashWriter(source_->GetType(), source_->GetVersion()), source(source_) {}
-
- void read(Span<std::byte> dst)
- {
- source->read(dst);
- this->write(dst);
- }
-
- void ignore(size_t nSize)
- {
- std::byte data[1024];
- while (nSize > 0) {
- size_t now = std::min<size_t>(nSize, 1024);
- read({data, now});
- nSize -= now;
- }
- }
-
- template<typename T>
- CHashVerifier<Source>& operator>>(T&& obj)
- {
- ::Unserialize(*this, obj);
- return (*this);
- }
-};
-
/** Writes data to an underlying source stream, while hashing the written data. */
template <typename Source>
-class HashedSourceWriter : public CHashWriter
+class HashedSourceWriter : public HashWriter
{
private:
Source& m_source;
public:
- explicit HashedSourceWriter(Source& source LIFETIMEBOUND) : CHashWriter{source.GetType(), source.GetVersion()}, m_source{source} {}
+ explicit HashedSourceWriter(Source& source LIFETIMEBOUND) : HashWriter{}, m_source{source} {}
void write(Span<const std::byte> src)
{
m_source.write(src);
- CHashWriter::write(src);
+ HashWriter::write(src);
}
template <typename T>
diff --git a/src/index/blockfilterindex.cpp b/src/index/blockfilterindex.cpp
index b23d66ac1d..21132d9305 100644
--- a/src/index/blockfilterindex.cpp
+++ b/src/index/blockfilterindex.cpp
@@ -4,6 +4,7 @@
#include <map>
+#include <clientversion.h>
#include <common/args.h>
#include <dbwrapper.h>
#include <hash.h>
diff --git a/src/index/disktxpos.h b/src/index/disktxpos.h
index 7718755b78..1004f7ae87 100644
--- a/src/index/disktxpos.h
+++ b/src/index/disktxpos.h
@@ -14,8 +14,7 @@ struct CDiskTxPos : public FlatFilePos
SERIALIZE_METHODS(CDiskTxPos, obj)
{
- READWRITEAS(FlatFilePos, obj);
- READWRITE(VARINT(obj.nTxOffset));
+ READWRITE(AsBase<FlatFilePos>(obj), VARINT(obj.nTxOffset));
}
CDiskTxPos(const FlatFilePos &blockIn, unsigned int nTxOffsetIn) : FlatFilePos(blockIn.nFile, blockIn.nPos), nTxOffset(nTxOffsetIn) {
diff --git a/src/index/txindex.cpp b/src/index/txindex.cpp
index 2e07a35d0d..0d4de3a53e 100644
--- a/src/index/txindex.cpp
+++ b/src/index/txindex.cpp
@@ -4,6 +4,7 @@
#include <index/txindex.h>
+#include <clientversion.h>
#include <common/args.h>
#include <index/disktxpos.h>
#include <logging.h>
diff --git a/src/init.cpp b/src/init.cpp
index 96fec92133..f0847bd4f7 100644
--- a/src/init.cpp
+++ b/src/init.cpp
@@ -19,6 +19,7 @@
#include <chain.h>
#include <chainparams.h>
#include <chainparamsbase.h>
+#include <clientversion.h>
#include <common/args.h>
#include <common/system.h>
#include <consensus/amount.h>
diff --git a/src/net.cpp b/src/net.cpp
index e66c0ec7f8..3955005dfa 100644
--- a/src/net.cpp
+++ b/src/net.cpp
@@ -202,7 +202,8 @@ static std::vector<CAddress> ConvertSeeds(const std::vector<uint8_t> &vSeedsIn)
const auto one_week{7 * 24h};
std::vector<CAddress> vSeedsOut;
FastRandomContext rng;
- CDataStream s(vSeedsIn, SER_NETWORK, PROTOCOL_VERSION | ADDRV2_FORMAT);
+ DataStream underlying_stream{vSeedsIn};
+ ParamsStream s{CAddress::V2_NETWORK, underlying_stream};
while (!s.eof()) {
CService endpoint;
s >> endpoint;
@@ -866,20 +867,22 @@ bool V1Transport::SetMessageToSend(CSerializedNetMsg& msg) noexcept
return true;
}
-Transport::BytesToSend V1Transport::GetBytesToSend() const noexcept
+Transport::BytesToSend V1Transport::GetBytesToSend(bool have_next_message) const noexcept
{
AssertLockNotHeld(m_send_mutex);
LOCK(m_send_mutex);
if (m_sending_header) {
return {Span{m_header_to_send}.subspan(m_bytes_sent),
- // We have more to send after the header if the message has payload.
- !m_message_to_send.data.empty(),
+ // We have more to send after the header if the message has payload, or if there
+ // is a next message after that.
+ have_next_message || !m_message_to_send.data.empty(),
m_message_to_send.m_type
};
} else {
return {Span{m_message_to_send.data}.subspan(m_bytes_sent),
- // We never have more to send after this message's payload.
- false,
+ // We only have more to send after this message's payload if there is another
+ // message.
+ have_next_message,
m_message_to_send.m_type
};
}
@@ -910,16 +913,676 @@ size_t V1Transport::GetSendMemoryUsage() const noexcept
return m_message_to_send.GetMemoryUsage();
}
+namespace {
+
+/** List of short messages as defined in BIP324, in order.
+ *
+ * Only message types that are actually implemented in this codebase need to be listed, as other
+ * messages get ignored anyway - whether we know how to decode them or not.
+ */
+const std::array<std::string, 33> V2_MESSAGE_IDS = {
+ "", // 12 bytes follow encoding the message type like in V1
+ NetMsgType::ADDR,
+ NetMsgType::BLOCK,
+ NetMsgType::BLOCKTXN,
+ NetMsgType::CMPCTBLOCK,
+ NetMsgType::FEEFILTER,
+ NetMsgType::FILTERADD,
+ NetMsgType::FILTERCLEAR,
+ NetMsgType::FILTERLOAD,
+ NetMsgType::GETBLOCKS,
+ NetMsgType::GETBLOCKTXN,
+ NetMsgType::GETDATA,
+ NetMsgType::GETHEADERS,
+ NetMsgType::HEADERS,
+ NetMsgType::INV,
+ NetMsgType::MEMPOOL,
+ NetMsgType::MERKLEBLOCK,
+ NetMsgType::NOTFOUND,
+ NetMsgType::PING,
+ NetMsgType::PONG,
+ NetMsgType::SENDCMPCT,
+ NetMsgType::TX,
+ NetMsgType::GETCFILTERS,
+ NetMsgType::CFILTER,
+ NetMsgType::GETCFHEADERS,
+ NetMsgType::CFHEADERS,
+ NetMsgType::GETCFCHECKPT,
+ NetMsgType::CFCHECKPT,
+ NetMsgType::ADDRV2,
+ // Unimplemented message types that are assigned in BIP324:
+ "",
+ "",
+ "",
+ ""
+};
+
+class V2MessageMap
+{
+ std::unordered_map<std::string, uint8_t> m_map;
+
+public:
+ V2MessageMap() noexcept
+ {
+ for (size_t i = 1; i < std::size(V2_MESSAGE_IDS); ++i) {
+ m_map.emplace(V2_MESSAGE_IDS[i], i);
+ }
+ }
+
+ std::optional<uint8_t> operator()(const std::string& message_name) const noexcept
+ {
+ auto it = m_map.find(message_name);
+ if (it == m_map.end()) return std::nullopt;
+ return it->second;
+ }
+};
+
+const V2MessageMap V2_MESSAGE_MAP;
+
+} // namespace
+
+V2Transport::V2Transport(NodeId nodeid, bool initiating, int type_in, int version_in) noexcept :
+ m_cipher{}, m_initiating{initiating}, m_nodeid{nodeid},
+ m_v1_fallback{nodeid, type_in, version_in}, m_recv_type{type_in}, m_recv_version{version_in},
+ m_recv_state{initiating ? RecvState::KEY : RecvState::KEY_MAYBE_V1},
+ m_send_state{initiating ? SendState::AWAITING_KEY : SendState::MAYBE_V1}
+{
+ // Construct garbage (including its length) using a FastRandomContext.
+ FastRandomContext rng;
+ size_t garbage_len = rng.randrange(MAX_GARBAGE_LEN + 1);
+ // Initialize the send buffer with ellswift pubkey + garbage.
+ m_send_buffer.resize(EllSwiftPubKey::size() + garbage_len);
+ std::copy(std::begin(m_cipher.GetOurPubKey()), std::end(m_cipher.GetOurPubKey()), MakeWritableByteSpan(m_send_buffer).begin());
+ rng.fillrand(MakeWritableByteSpan(m_send_buffer).subspan(EllSwiftPubKey::size()));
+}
+
+V2Transport::V2Transport(NodeId nodeid, bool initiating, int type_in, int version_in, const CKey& key, Span<const std::byte> ent32, Span<const uint8_t> garbage) noexcept :
+ m_cipher{key, ent32}, m_initiating{initiating}, m_nodeid{nodeid},
+ m_v1_fallback{nodeid, type_in, version_in}, m_recv_type{type_in}, m_recv_version{version_in},
+ m_recv_state{initiating ? RecvState::KEY : RecvState::KEY_MAYBE_V1},
+ m_send_state{initiating ? SendState::AWAITING_KEY : SendState::MAYBE_V1}
+{
+ assert(garbage.size() <= MAX_GARBAGE_LEN);
+ // Initialize the send buffer with ellswift pubkey + provided garbage.
+ m_send_buffer.resize(EllSwiftPubKey::size() + garbage.size());
+ std::copy(std::begin(m_cipher.GetOurPubKey()), std::end(m_cipher.GetOurPubKey()), MakeWritableByteSpan(m_send_buffer).begin());
+ std::copy(garbage.begin(), garbage.end(), m_send_buffer.begin() + EllSwiftPubKey::size());
+}
+
+void V2Transport::SetReceiveState(RecvState recv_state) noexcept
+{
+ AssertLockHeld(m_recv_mutex);
+ // Enforce allowed state transitions.
+ switch (m_recv_state) {
+ case RecvState::KEY_MAYBE_V1:
+ Assume(recv_state == RecvState::KEY || recv_state == RecvState::V1);
+ break;
+ case RecvState::KEY:
+ Assume(recv_state == RecvState::GARB_GARBTERM);
+ break;
+ case RecvState::GARB_GARBTERM:
+ Assume(recv_state == RecvState::GARBAUTH);
+ break;
+ case RecvState::GARBAUTH:
+ Assume(recv_state == RecvState::VERSION);
+ break;
+ case RecvState::VERSION:
+ Assume(recv_state == RecvState::APP);
+ break;
+ case RecvState::APP:
+ Assume(recv_state == RecvState::APP_READY);
+ break;
+ case RecvState::APP_READY:
+ Assume(recv_state == RecvState::APP);
+ break;
+ case RecvState::V1:
+ Assume(false); // V1 state cannot be left
+ break;
+ }
+ // Change state.
+ m_recv_state = recv_state;
+}
+
+void V2Transport::SetSendState(SendState send_state) noexcept
+{
+ AssertLockHeld(m_send_mutex);
+ // Enforce allowed state transitions.
+ switch (m_send_state) {
+ case SendState::MAYBE_V1:
+ Assume(send_state == SendState::V1 || send_state == SendState::AWAITING_KEY);
+ break;
+ case SendState::AWAITING_KEY:
+ Assume(send_state == SendState::READY);
+ break;
+ case SendState::READY:
+ case SendState::V1:
+ Assume(false); // Final states
+ break;
+ }
+ // Change state.
+ m_send_state = send_state;
+}
+
+bool V2Transport::ReceivedMessageComplete() const noexcept
+{
+ AssertLockNotHeld(m_recv_mutex);
+ LOCK(m_recv_mutex);
+ if (m_recv_state == RecvState::V1) return m_v1_fallback.ReceivedMessageComplete();
+
+ return m_recv_state == RecvState::APP_READY;
+}
+
+void V2Transport::ProcessReceivedMaybeV1Bytes() noexcept
+{
+ AssertLockHeld(m_recv_mutex);
+ AssertLockNotHeld(m_send_mutex);
+ Assume(m_recv_state == RecvState::KEY_MAYBE_V1);
+ // We still have to determine if this is a v1 or v2 connection. The bytes being received could
+ // be the beginning of either a v1 packet (network magic + "version\x00"), or of a v2 public
+ // key. BIP324 specifies that a mismatch with this 12-byte string should trigger sending of the
+ // key.
+ std::array<uint8_t, V1_PREFIX_LEN> v1_prefix = {0, 0, 0, 0, 'v', 'e', 'r', 's', 'i', 'o', 'n', 0};
+ std::copy(std::begin(Params().MessageStart()), std::end(Params().MessageStart()), v1_prefix.begin());
+ Assume(m_recv_buffer.size() <= v1_prefix.size());
+ if (!std::equal(m_recv_buffer.begin(), m_recv_buffer.end(), v1_prefix.begin())) {
+ // Mismatch with v1 prefix, so we can assume a v2 connection.
+ SetReceiveState(RecvState::KEY); // Convert to KEY state, leaving received bytes around.
+ // Transition the sender to AWAITING_KEY state (if not already).
+ LOCK(m_send_mutex);
+ SetSendState(SendState::AWAITING_KEY);
+ } else if (m_recv_buffer.size() == v1_prefix.size()) {
+ // Full match with the v1 prefix, so fall back to v1 behavior.
+ LOCK(m_send_mutex);
+ Span<const uint8_t> feedback{m_recv_buffer};
+ // Feed already received bytes to v1 transport. It should always accept these, because it's
+ // less than the size of a v1 header, and these are the first bytes fed to m_v1_fallback.
+ bool ret = m_v1_fallback.ReceivedBytes(feedback);
+ Assume(feedback.empty());
+ Assume(ret);
+ SetReceiveState(RecvState::V1);
+ SetSendState(SendState::V1);
+ // Reset v2 transport buffers to save memory.
+ m_recv_buffer = {};
+ m_send_buffer = {};
+ } else {
+ // We have not received enough to distinguish v1 from v2 yet. Wait until more bytes come.
+ }
+}
+
+bool V2Transport::ProcessReceivedKeyBytes() noexcept
+{
+ AssertLockHeld(m_recv_mutex);
+ AssertLockNotHeld(m_send_mutex);
+ Assume(m_recv_state == RecvState::KEY);
+ Assume(m_recv_buffer.size() <= EllSwiftPubKey::size());
+
+ // As a special exception, if bytes 4-16 of the key on a responder connection match the
+ // corresponding bytes of a V1 version message, but bytes 0-4 don't match the network magic
+ // (if they did, we'd have switched to V1 state already), assume this is a peer from
+ // another network, and disconnect them. They will almost certainly disconnect us too when
+ // they receive our uniformly random key and garbage, but detecting this case specially
+ // means we can log it.
+ static constexpr std::array<uint8_t, 12> MATCH = {'v', 'e', 'r', 's', 'i', 'o', 'n', 0, 0, 0, 0, 0};
+ static constexpr size_t OFFSET = sizeof(CMessageHeader::MessageStartChars);
+ if (!m_initiating && m_recv_buffer.size() >= OFFSET + MATCH.size()) {
+ if (std::equal(MATCH.begin(), MATCH.end(), m_recv_buffer.begin() + OFFSET)) {
+ LogPrint(BCLog::NET, "V2 transport error: V1 peer with wrong MessageStart %s\n",
+ HexStr(Span(m_recv_buffer).first(OFFSET)));
+ return false;
+ }
+ }
+
+ if (m_recv_buffer.size() == EllSwiftPubKey::size()) {
+ // Other side's key has been fully received, and can now be Diffie-Hellman combined with
+ // our key to initialize the encryption ciphers.
+
+ // Initialize the ciphers.
+ EllSwiftPubKey ellswift(MakeByteSpan(m_recv_buffer));
+ LOCK(m_send_mutex);
+ m_cipher.Initialize(ellswift, m_initiating);
+
+ // Switch receiver state to GARB_GARBTERM.
+ SetReceiveState(RecvState::GARB_GARBTERM);
+ m_recv_buffer.clear();
+
+ // Switch sender state to READY.
+ SetSendState(SendState::READY);
+
+ // Append the garbage terminator to the send buffer.
+ size_t garbage_len = m_send_buffer.size() - EllSwiftPubKey::size();
+ m_send_buffer.resize(m_send_buffer.size() + BIP324Cipher::GARBAGE_TERMINATOR_LEN);
+ std::copy(m_cipher.GetSendGarbageTerminator().begin(),
+ m_cipher.GetSendGarbageTerminator().end(),
+ MakeWritableByteSpan(m_send_buffer).last(BIP324Cipher::GARBAGE_TERMINATOR_LEN).begin());
+
+ // Construct garbage authentication packet in the send buffer (using the garbage data which
+ // is still there).
+ m_send_buffer.resize(m_send_buffer.size() + BIP324Cipher::EXPANSION);
+ m_cipher.Encrypt(
+ /*contents=*/{},
+ /*aad=*/MakeByteSpan(m_send_buffer).subspan(EllSwiftPubKey::size(), garbage_len),
+ /*ignore=*/false,
+ /*output=*/MakeWritableByteSpan(m_send_buffer).last(BIP324Cipher::EXPANSION));
+
+ // Construct version packet in the send buffer.
+ m_send_buffer.resize(m_send_buffer.size() + BIP324Cipher::EXPANSION + VERSION_CONTENTS.size());
+ m_cipher.Encrypt(
+ /*contents=*/VERSION_CONTENTS,
+ /*aad=*/{},
+ /*ignore=*/false,
+ /*output=*/MakeWritableByteSpan(m_send_buffer).last(BIP324Cipher::EXPANSION + VERSION_CONTENTS.size()));
+ } else {
+ // We still have to receive more key bytes.
+ }
+ return true;
+}
+
+bool V2Transport::ProcessReceivedGarbageBytes() noexcept
+{
+ AssertLockHeld(m_recv_mutex);
+ Assume(m_recv_state == RecvState::GARB_GARBTERM);
+ Assume(m_recv_buffer.size() <= MAX_GARBAGE_LEN + BIP324Cipher::GARBAGE_TERMINATOR_LEN);
+ if (m_recv_buffer.size() >= BIP324Cipher::GARBAGE_TERMINATOR_LEN) {
+ if (MakeByteSpan(m_recv_buffer).last(BIP324Cipher::GARBAGE_TERMINATOR_LEN) == m_cipher.GetReceiveGarbageTerminator()) {
+ // Garbage terminator received. Switch to receiving garbage authentication packet.
+ m_recv_garbage = std::move(m_recv_buffer);
+ m_recv_garbage.resize(m_recv_garbage.size() - BIP324Cipher::GARBAGE_TERMINATOR_LEN);
+ m_recv_buffer.clear();
+ SetReceiveState(RecvState::GARBAUTH);
+ } else if (m_recv_buffer.size() == MAX_GARBAGE_LEN + BIP324Cipher::GARBAGE_TERMINATOR_LEN) {
+ // We've reached the maximum length for garbage + garbage terminator, and the
+ // terminator still does not match. Abort.
+ LogPrint(BCLog::NET, "V2 transport error: missing garbage terminator, peer=%d\n", m_nodeid);
+ return false;
+ } else {
+ // We still need to receive more garbage and/or garbage terminator bytes.
+ }
+ } else {
+ // We have less than GARBAGE_TERMINATOR_LEN (16) bytes, so we certainly need to receive
+ // more first.
+ }
+ return true;
+}
+
+bool V2Transport::ProcessReceivedPacketBytes() noexcept
+{
+ AssertLockHeld(m_recv_mutex);
+ Assume(m_recv_state == RecvState::GARBAUTH || m_recv_state == RecvState::VERSION ||
+ m_recv_state == RecvState::APP);
+
+ // The maximum permitted contents length for a packet, consisting of:
+ // - 0x00 byte: indicating long message type encoding
+ // - 12 bytes of message type
+ // - payload
+ static constexpr size_t MAX_CONTENTS_LEN =
+ 1 + CMessageHeader::COMMAND_SIZE +
+ std::min<size_t>(MAX_SIZE, MAX_PROTOCOL_MESSAGE_LENGTH);
+
+ if (m_recv_buffer.size() == BIP324Cipher::LENGTH_LEN) {
+ // Length descriptor received.
+ m_recv_len = m_cipher.DecryptLength(MakeByteSpan(m_recv_buffer));
+ if (m_recv_len > MAX_CONTENTS_LEN) {
+ LogPrint(BCLog::NET, "V2 transport error: packet too large (%u bytes), peer=%d\n", m_recv_len, m_nodeid);
+ return false;
+ }
+ } else if (m_recv_buffer.size() > BIP324Cipher::LENGTH_LEN && m_recv_buffer.size() == m_recv_len + BIP324Cipher::EXPANSION) {
+ // Ciphertext received, decrypt it into m_recv_decode_buffer.
+ // Note that it is impossible to reach this branch without hitting the branch above first,
+ // as GetMaxBytesToProcess only allows up to LENGTH_LEN into the buffer before that point.
+ m_recv_decode_buffer.resize(m_recv_len);
+ bool ignore{false};
+ Span<const std::byte> aad;
+ if (m_recv_state == RecvState::GARBAUTH) aad = MakeByteSpan(m_recv_garbage);
+ bool ret = m_cipher.Decrypt(
+ /*input=*/MakeByteSpan(m_recv_buffer).subspan(BIP324Cipher::LENGTH_LEN),
+ /*aad=*/aad,
+ /*ignore=*/ignore,
+ /*contents=*/MakeWritableByteSpan(m_recv_decode_buffer));
+ if (!ret) {
+ LogPrint(BCLog::NET, "V2 transport error: packet decryption failure (%u bytes), peer=%d\n", m_recv_len, m_nodeid);
+ return false;
+ }
+ // Feed the last 4 bytes of the Poly1305 authentication tag (and its timing) into our RNG.
+ RandAddEvent(ReadLE32(m_recv_buffer.data() + m_recv_buffer.size() - 4));
+
+ // At this point we have a valid packet decrypted into m_recv_decode_buffer. Depending on
+ // the current state, decide what to do with it.
+ switch (m_recv_state) {
+ case RecvState::GARBAUTH:
+ // Ignore flag does not matter for garbage authentication. Any valid packet functions
+ // as authentication. Receive and process the version packet next.
+ SetReceiveState(RecvState::VERSION);
+ m_recv_garbage = {};
+ break;
+ case RecvState::VERSION:
+ if (!ignore) {
+ // Version message received; transition to application phase. The contents is
+ // ignored, but can be used for future extensions.
+ SetReceiveState(RecvState::APP);
+ }
+ break;
+ case RecvState::APP:
+ if (!ignore) {
+ // Application message decrypted correctly. It can be extracted using GetMessage().
+ SetReceiveState(RecvState::APP_READY);
+ }
+ break;
+ default:
+ // Any other state is invalid (this function should not have been called).
+ Assume(false);
+ }
+ // Wipe the receive buffer where the next packet will be received into.
+ m_recv_buffer = {};
+ // In all but APP_READY state, we can wipe the decoded contents.
+ if (m_recv_state != RecvState::APP_READY) m_recv_decode_buffer = {};
+ } else {
+ // We either have less than 3 bytes, so we don't know the packet's length yet, or more
+ // than 3 bytes but less than the packet's full ciphertext. Wait until those arrive.
+ }
+ return true;
+}
+
+size_t V2Transport::GetMaxBytesToProcess() noexcept
+{
+ AssertLockHeld(m_recv_mutex);
+ switch (m_recv_state) {
+ case RecvState::KEY_MAYBE_V1:
+ // During the KEY_MAYBE_V1 state we do not allow more than the length of v1 prefix into the
+ // receive buffer.
+ Assume(m_recv_buffer.size() <= V1_PREFIX_LEN);
+ // As long as we're not sure if this is a v1 or v2 connection, don't receive more than what
+ // is strictly necessary to distinguish the two (12 bytes). If we permitted more than
+ // the v1 header size (24 bytes), we may not be able to feed the already-received bytes
+ // back into the m_v1_fallback V1 transport.
+ return V1_PREFIX_LEN - m_recv_buffer.size();
+ case RecvState::KEY:
+ // During the KEY state, we only allow the 64-byte key into the receive buffer.
+ Assume(m_recv_buffer.size() <= EllSwiftPubKey::size());
+ // As long as we have not received the other side's public key, don't receive more than
+ // that (64 bytes), as garbage follows, and locating the garbage terminator requires the
+ // key exchange first.
+ return EllSwiftPubKey::size() - m_recv_buffer.size();
+ case RecvState::GARB_GARBTERM:
+ // Process garbage bytes one by one (because terminator may appear anywhere).
+ return 1;
+ case RecvState::GARBAUTH:
+ case RecvState::VERSION:
+ case RecvState::APP:
+ // These three states all involve decoding a packet. Process the length descriptor first,
+ // so that we know where the current packet ends (and we don't process bytes from the next
+ // packet or decoy yet). Then, process the ciphertext bytes of the current packet.
+ if (m_recv_buffer.size() < BIP324Cipher::LENGTH_LEN) {
+ return BIP324Cipher::LENGTH_LEN - m_recv_buffer.size();
+ } else {
+ // Note that BIP324Cipher::EXPANSION is the total difference between contents size
+ // and encoded packet size, which includes the 3 bytes due to the packet length.
+ // When transitioning from receiving the packet length to receiving its ciphertext,
+ // the encrypted packet length is left in the receive buffer.
+ return BIP324Cipher::EXPANSION + m_recv_len - m_recv_buffer.size();
+ }
+ case RecvState::APP_READY:
+ // No bytes can be processed until GetMessage() is called.
+ return 0;
+ case RecvState::V1:
+ // Not allowed (must be dealt with by the caller).
+ Assume(false);
+ return 0;
+ }
+ Assume(false); // unreachable
+ return 0;
+}
+
+bool V2Transport::ReceivedBytes(Span<const uint8_t>& msg_bytes) noexcept
+{
+ AssertLockNotHeld(m_recv_mutex);
+ /** How many bytes to allocate in the receive buffer at most above what is received so far. */
+ static constexpr size_t MAX_RESERVE_AHEAD = 256 * 1024;
+
+ LOCK(m_recv_mutex);
+ if (m_recv_state == RecvState::V1) return m_v1_fallback.ReceivedBytes(msg_bytes);
+
+ // Process the provided bytes in msg_bytes in a loop. In each iteration a nonzero number of
+ // bytes (decided by GetMaxBytesToProcess) are taken from the beginning om msg_bytes, and
+ // appended to m_recv_buffer. Then, depending on the receiver state, one of the
+ // ProcessReceived*Bytes functions is called to process the bytes in that buffer.
+ while (!msg_bytes.empty()) {
+ // Decide how many bytes to copy from msg_bytes to m_recv_buffer.
+ size_t max_read = GetMaxBytesToProcess();
+
+ // Reserve space in the buffer if there is not enough.
+ if (m_recv_buffer.size() + std::min(msg_bytes.size(), max_read) > m_recv_buffer.capacity()) {
+ switch (m_recv_state) {
+ case RecvState::KEY_MAYBE_V1:
+ case RecvState::KEY:
+ case RecvState::GARB_GARBTERM:
+ // During the initial states (key/garbage), allocate once to fit the maximum (4111
+ // bytes).
+ m_recv_buffer.reserve(MAX_GARBAGE_LEN + BIP324Cipher::GARBAGE_TERMINATOR_LEN);
+ break;
+ case RecvState::GARBAUTH:
+ case RecvState::VERSION:
+ case RecvState::APP: {
+ // During states where a packet is being received, as much as is expected but never
+ // more than MAX_RESERVE_AHEAD bytes in addition to what is received so far.
+ // This means attackers that want to cause us to waste allocated memory are limited
+ // to MAX_RESERVE_AHEAD above the largest allowed message contents size, and to
+ // MAX_RESERVE_AHEAD more than they've actually sent us.
+ size_t alloc_add = std::min(max_read, msg_bytes.size() + MAX_RESERVE_AHEAD);
+ m_recv_buffer.reserve(m_recv_buffer.size() + alloc_add);
+ break;
+ }
+ case RecvState::APP_READY:
+ // The buffer is empty in this state.
+ Assume(m_recv_buffer.empty());
+ break;
+ case RecvState::V1:
+ // Should have bailed out above.
+ Assume(false);
+ break;
+ }
+ }
+
+ // Can't read more than provided input.
+ max_read = std::min(msg_bytes.size(), max_read);
+ // Copy data to buffer.
+ m_recv_buffer.insert(m_recv_buffer.end(), UCharCast(msg_bytes.data()), UCharCast(msg_bytes.data() + max_read));
+ msg_bytes = msg_bytes.subspan(max_read);
+
+ // Process data in the buffer.
+ switch (m_recv_state) {
+ case RecvState::KEY_MAYBE_V1:
+ ProcessReceivedMaybeV1Bytes();
+ if (m_recv_state == RecvState::V1) return true;
+ break;
+
+ case RecvState::KEY:
+ if (!ProcessReceivedKeyBytes()) return false;
+ break;
+
+ case RecvState::GARB_GARBTERM:
+ if (!ProcessReceivedGarbageBytes()) return false;
+ break;
+
+ case RecvState::GARBAUTH:
+ case RecvState::VERSION:
+ case RecvState::APP:
+ if (!ProcessReceivedPacketBytes()) return false;
+ break;
+
+ case RecvState::APP_READY:
+ return true;
+
+ case RecvState::V1:
+ // We should have bailed out before.
+ Assume(false);
+ break;
+ }
+ // Make sure we have made progress before continuing.
+ Assume(max_read > 0);
+ }
+
+ return true;
+}
+
+std::optional<std::string> V2Transport::GetMessageType(Span<const uint8_t>& contents) noexcept
+{
+ if (contents.size() == 0) return std::nullopt; // Empty contents
+ uint8_t first_byte = contents[0];
+ contents = contents.subspan(1); // Strip first byte.
+
+ if (first_byte != 0) {
+ // Short (1 byte) encoding.
+ if (first_byte < std::size(V2_MESSAGE_IDS)) {
+ // Valid short message id.
+ return V2_MESSAGE_IDS[first_byte];
+ } else {
+ // Unknown short message id.
+ return std::nullopt;
+ }
+ }
+
+ if (contents.size() < CMessageHeader::COMMAND_SIZE) {
+ return std::nullopt; // Long encoding needs 12 message type bytes.
+ }
+
+ size_t msg_type_len{0};
+ while (msg_type_len < CMessageHeader::COMMAND_SIZE && contents[msg_type_len] != 0) {
+ // Verify that message type bytes before the first 0x00 are in range.
+ if (contents[msg_type_len] < ' ' || contents[msg_type_len] > 0x7F) {
+ return {};
+ }
+ ++msg_type_len;
+ }
+ std::string ret{reinterpret_cast<const char*>(contents.data()), msg_type_len};
+ while (msg_type_len < CMessageHeader::COMMAND_SIZE) {
+ // Verify that message type bytes after the first 0x00 are also 0x00.
+ if (contents[msg_type_len] != 0) return {};
+ ++msg_type_len;
+ }
+ // Strip message type bytes of contents.
+ contents = contents.subspan(CMessageHeader::COMMAND_SIZE);
+ return {std::move(ret)};
+}
+
+CNetMessage V2Transport::GetReceivedMessage(std::chrono::microseconds time, bool& reject_message) noexcept
+{
+ AssertLockNotHeld(m_recv_mutex);
+ LOCK(m_recv_mutex);
+ if (m_recv_state == RecvState::V1) return m_v1_fallback.GetReceivedMessage(time, reject_message);
+
+ Assume(m_recv_state == RecvState::APP_READY);
+ Span<const uint8_t> contents{m_recv_decode_buffer};
+ auto msg_type = GetMessageType(contents);
+ CDataStream ret(m_recv_type, m_recv_version);
+ CNetMessage msg{std::move(ret)};
+ // Note that BIP324Cipher::EXPANSION also includes the length descriptor size.
+ msg.m_raw_message_size = m_recv_decode_buffer.size() + BIP324Cipher::EXPANSION;
+ if (msg_type) {
+ reject_message = false;
+ msg.m_type = std::move(*msg_type);
+ msg.m_time = time;
+ msg.m_message_size = contents.size();
+ msg.m_recv.resize(contents.size());
+ std::copy(contents.begin(), contents.end(), UCharCast(msg.m_recv.data()));
+ } else {
+ LogPrint(BCLog::NET, "V2 transport error: invalid message type (%u bytes contents), peer=%d\n", m_recv_decode_buffer.size(), m_nodeid);
+ reject_message = true;
+ }
+ m_recv_decode_buffer = {};
+ SetReceiveState(RecvState::APP);
+
+ return msg;
+}
+
+bool V2Transport::SetMessageToSend(CSerializedNetMsg& msg) noexcept
+{
+ AssertLockNotHeld(m_send_mutex);
+ LOCK(m_send_mutex);
+ if (m_send_state == SendState::V1) return m_v1_fallback.SetMessageToSend(msg);
+ // We only allow adding a new message to be sent when in the READY state (so the packet cipher
+ // is available) and the send buffer is empty. This limits the number of messages in the send
+ // buffer to just one, and leaves the responsibility for queueing them up to the caller.
+ if (!(m_send_state == SendState::READY && m_send_buffer.empty())) return false;
+ // Construct contents (encoding message type + payload).
+ std::vector<uint8_t> contents;
+ auto short_message_id = V2_MESSAGE_MAP(msg.m_type);
+ if (short_message_id) {
+ contents.resize(1 + msg.data.size());
+ contents[0] = *short_message_id;
+ std::copy(msg.data.begin(), msg.data.end(), contents.begin() + 1);
+ } else {
+ // Initialize with zeroes, and then write the message type string starting at offset 1.
+ // This means contents[0] and the unused positions in contents[1..13] remain 0x00.
+ contents.resize(1 + CMessageHeader::COMMAND_SIZE + msg.data.size(), 0);
+ std::copy(msg.m_type.begin(), msg.m_type.end(), contents.data() + 1);
+ std::copy(msg.data.begin(), msg.data.end(), contents.begin() + 1 + CMessageHeader::COMMAND_SIZE);
+ }
+ // Construct ciphertext in send buffer.
+ m_send_buffer.resize(contents.size() + BIP324Cipher::EXPANSION);
+ m_cipher.Encrypt(MakeByteSpan(contents), {}, false, MakeWritableByteSpan(m_send_buffer));
+ m_send_type = msg.m_type;
+ // Release memory
+ msg.data = {};
+ return true;
+}
+
+Transport::BytesToSend V2Transport::GetBytesToSend(bool have_next_message) const noexcept
+{
+ AssertLockNotHeld(m_send_mutex);
+ LOCK(m_send_mutex);
+ if (m_send_state == SendState::V1) return m_v1_fallback.GetBytesToSend(have_next_message);
+
+ // We do not send anything in MAYBE_V1 state (as we don't know if the peer is v1 or v2),
+ // despite there being data in the send buffer in that state.
+ if (m_send_state == SendState::MAYBE_V1) return {{}, false, m_send_type};
+ Assume(m_send_pos <= m_send_buffer.size());
+ return {
+ Span{m_send_buffer}.subspan(m_send_pos),
+ // We only have more to send after the current m_send_buffer if there is a (next)
+ // message to be sent, and we're capable of sending packets. */
+ have_next_message && m_send_state == SendState::READY,
+ m_send_type
+ };
+}
+
+void V2Transport::MarkBytesSent(size_t bytes_sent) noexcept
+{
+ AssertLockNotHeld(m_send_mutex);
+ LOCK(m_send_mutex);
+ if (m_send_state == SendState::V1) return m_v1_fallback.MarkBytesSent(bytes_sent);
+
+ m_send_pos += bytes_sent;
+ Assume(m_send_pos <= m_send_buffer.size());
+ // Only wipe the buffer when everything is sent in the READY state. In the AWAITING_KEY state
+ // we still need the garbage that's in the send buffer to construct the garbage authentication
+ // packet.
+ if (m_send_state == SendState::READY && m_send_pos == m_send_buffer.size()) {
+ m_send_pos = 0;
+ m_send_buffer = {};
+ }
+}
+
+size_t V2Transport::GetSendMemoryUsage() const noexcept
+{
+ AssertLockNotHeld(m_send_mutex);
+ LOCK(m_send_mutex);
+ if (m_send_state == SendState::V1) return m_v1_fallback.GetSendMemoryUsage();
+
+ return sizeof(m_send_buffer) + memusage::DynamicUsage(m_send_buffer);
+}
+
std::pair<size_t, bool> CConnman::SocketSendData(CNode& node) const
{
auto it = node.vSendMsg.begin();
size_t nSentSize = 0;
bool data_left{false}; //!< second return value (whether unsent data remains)
+ std::optional<bool> expected_more;
while (true) {
if (it != node.vSendMsg.end()) {
// If possible, move one message from the send queue to the transport. This fails when
- // there is an existing message still being sent.
+ // there is an existing message still being sent, or (for v2 transports) when the
+ // handshake has not yet completed.
size_t memusage = it->GetMemoryUsage();
if (node.m_transport->SetMessageToSend(*it)) {
// Update memory usage of send buffer (as *it will be deleted).
@@ -927,7 +1590,12 @@ std::pair<size_t, bool> CConnman::SocketSendData(CNode& node) const
++it;
}
}
- const auto& [data, more, msg_type] = node.m_transport->GetBytesToSend();
+ const auto& [data, more, msg_type] = node.m_transport->GetBytesToSend(it != node.vSendMsg.end());
+ // We rely on the 'more' value returned by GetBytesToSend to correctly predict whether more
+ // bytes are still to be sent, to correctly set the MSG_MORE flag. As a sanity check,
+ // verify that the previously returned 'more' was correct.
+ if (expected_more.has_value()) Assume(!data.empty() == *expected_more);
+ expected_more = more;
data_left = !data.empty(); // will be overwritten on next loop if all of data gets sent
int nBytes = 0;
if (!data.empty()) {
@@ -940,9 +1608,7 @@ std::pair<size_t, bool> CConnman::SocketSendData(CNode& node) const
}
int flags = MSG_NOSIGNAL | MSG_DONTWAIT;
#ifdef MSG_MORE
- // We have more to send if either the transport itself has more, or if we have more
- // messages to send.
- if (more || it != node.vSendMsg.end()) {
+ if (more) {
flags |= MSG_MORE;
}
#endif
@@ -1322,9 +1988,10 @@ Sock::EventsPerSock CConnman::GenerateWaitSockets(Span<CNode* const> nodes)
{
LOCK(pnode->cs_vSend);
// Sending is possible if either there are bytes to send right now, or if there will be
- // once a potential message from vSendMsg is handed to the transport.
- const auto& [to_send, _more, _msg_type] = pnode->m_transport->GetBytesToSend();
- select_send = !to_send.empty() || !pnode->vSendMsg.empty();
+ // once a potential message from vSendMsg is handed to the transport. GetBytesToSend
+ // determines both of these in a single call.
+ const auto& [to_send, more, _msg_type] = pnode->m_transport->GetBytesToSend(!pnode->vSendMsg.empty());
+ select_send = !to_send.empty() || more;
}
if (!select_recv && !select_send) continue;
@@ -3006,7 +3673,10 @@ void CConnman::PushMessage(CNode* pnode, CSerializedNetMsg&& msg)
size_t nBytesSent = 0;
{
LOCK(pnode->cs_vSend);
- const auto& [to_send, _more, _msg_type] = pnode->m_transport->GetBytesToSend();
+ // Check if the transport still has unsent bytes, and indicate to it that we're about to
+ // give it a message to send.
+ const auto& [to_send, more, _msg_type] =
+ pnode->m_transport->GetBytesToSend(/*have_next_message=*/true);
const bool queue_was_empty{to_send.empty() && pnode->vSendMsg.empty()};
// Update memory usage of send buffer.
@@ -3015,10 +3685,14 @@ void CConnman::PushMessage(CNode* pnode, CSerializedNetMsg&& msg)
// Move message to vSendMsg queue.
pnode->vSendMsg.push_back(std::move(msg));
- // If there was nothing to send before, attempt "optimistic write":
+ // If there was nothing to send before, and there is now (predicted by the "more" value
+ // returned by the GetBytesToSend call above), attempt "optimistic write":
// because the poll/select loop may pause for SELECT_TIMEOUT_MILLISECONDS before actually
// doing a send, try sending from the calling thread if the queue was empty before.
- if (queue_was_empty) {
+ // With a V1Transport, more will always be true here, because adding a message always
+ // results in sendable bytes there, but with V2Transport this is not the case (it may
+ // still be in the handshake).
+ if (queue_was_empty && more) {
std::tie(nBytesSent, std::ignore) = SocketSendData(*pnode);
}
}
diff --git a/src/net.h b/src/net.h
index 60a15fea55..cf7a240202 100644
--- a/src/net.h
+++ b/src/net.h
@@ -6,6 +6,7 @@
#ifndef BITCOIN_NET_H
#define BITCOIN_NET_H
+#include <bip324.h>
#include <chainparams.h>
#include <common/bloom.h>
#include <compat/compat.h>
@@ -266,8 +267,6 @@ public:
/** Returns true if the current message is complete (so GetReceivedMessage can be called). */
virtual bool ReceivedMessageComplete() const = 0;
- /** Set the deserialization context version for objects returned by GetReceivedMessage. */
- virtual void SetReceiveVersion(int version) = 0;
/** Feed wire bytes to the transport.
*
@@ -300,7 +299,8 @@ public:
* - Span<const uint8_t> to_send: span of bytes to be sent over the wire (possibly empty).
* - bool more: whether there will be more bytes to be sent after the ones in to_send are
* all sent (as signaled by MarkBytesSent()).
- * - const std::string& m_type: message type on behalf of which this is being sent.
+ * - const std::string& m_type: message type on behalf of which this is being sent
+ * ("" for bytes that are not on behalf of any message).
*/
using BytesToSend = std::tuple<
Span<const uint8_t> /*to_send*/,
@@ -308,19 +308,42 @@ public:
const std::string& /*m_type*/
>;
- /** Get bytes to send on the wire.
+ /** Get bytes to send on the wire, if any, along with other information about it.
*
* As a const function, it does not modify the transport's observable state, and is thus safe
* to be called multiple times.
*
- * The bytes returned by this function act as a stream which can only be appended to. This
- * means that with the exception of MarkBytesSent, operations on the transport can only append
- * to what is being returned.
+ * @param[in] have_next_message If true, the "more" return value reports whether more will
+ * be sendable after a SetMessageToSend call. It is set by the caller when they know
+ * they have another message ready to send, and only care about what happens
+ * after that. The have_next_message argument only affects this "more" return value
+ * and nothing else.
*
- * Note that m_type and to_send refer to data that is internal to the transport, and calling
- * any non-const function on this object may invalidate them.
+ * Effectively, there are three possible outcomes about whether there are more bytes
+ * to send:
+ * - Yes: the transport itself has more bytes to send later. For example, for
+ * V1Transport this happens during the sending of the header of a
+ * message, when there is a non-empty payload that follows.
+ * - No: the transport itself has no more bytes to send, but will have bytes to
+ * send if handed a message through SetMessageToSend. In V1Transport this
+ * happens when sending the payload of a message.
+ * - Blocked: the transport itself has no more bytes to send, and is also incapable
+ * of sending anything more at all now, if it were handed another
+ * message to send. This occurs in V2Transport before the handshake is
+ * complete, as the encryption ciphers are not set up for sending
+ * messages before that point.
+ *
+ * The boolean 'more' is true for Yes, false for Blocked, and have_next_message
+ * controls what is returned for No.
+ *
+ * @return a BytesToSend object. The to_send member returned acts as a stream which is only
+ * ever appended to. This means that with the exception of MarkBytesSent (which pops
+ * bytes off the front of later to_sends), operations on the transport can only append
+ * to what is being returned. Also note that m_type and to_send refer to data that is
+ * internal to the transport, and calling any non-const function on this object may
+ * invalidate them.
*/
- virtual BytesToSend GetBytesToSend() const noexcept = 0;
+ virtual BytesToSend GetBytesToSend(bool have_next_message) const noexcept = 0;
/** Report how many bytes returned by the last GetBytesToSend() have been sent.
*
@@ -392,14 +415,6 @@ public:
return WITH_LOCK(m_recv_mutex, return CompleteInternal());
}
- void SetReceiveVersion(int nVersionIn) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex)
- {
- AssertLockNotHeld(m_recv_mutex);
- LOCK(m_recv_mutex);
- hdrbuf.SetVersion(nVersionIn);
- vRecv.SetVersion(nVersionIn);
- }
-
bool ReceivedBytes(Span<const uint8_t>& msg_bytes) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex)
{
AssertLockNotHeld(m_recv_mutex);
@@ -416,7 +431,221 @@ public:
CNetMessage GetReceivedMessage(std::chrono::microseconds time, bool& reject_message) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex);
bool SetMessageToSend(CSerializedNetMsg& msg) noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex);
- BytesToSend GetBytesToSend() const noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex);
+ BytesToSend GetBytesToSend(bool have_next_message) const noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex);
+ void MarkBytesSent(size_t bytes_sent) noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex);
+ size_t GetSendMemoryUsage() const noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex);
+};
+
+class V2Transport final : public Transport
+{
+private:
+ /** Contents of the version packet to send. BIP324 stipulates that senders should leave this
+ * empty, and receivers should ignore it. Future extensions can change what is sent as long as
+ * an empty version packet contents is interpreted as no extensions supported. */
+ static constexpr std::array<std::byte, 0> VERSION_CONTENTS = {};
+
+ /** The length of the V1 prefix to match bytes initially received by responders with to
+ * determine if their peer is speaking V1 or V2. */
+ static constexpr size_t V1_PREFIX_LEN = 12;
+
+ // The sender side and receiver side of V2Transport are state machines that are transitioned
+ // through, based on what has been received. The receive state corresponds to the contents of,
+ // and bytes received to, the receive buffer. The send state controls what can be appended to
+ // the send buffer and what can be sent from it.
+
+ /** State type that defines the current contents of the receive buffer and/or how the next
+ * received bytes added to it will be interpreted.
+ *
+ * Diagram:
+ *
+ * start(responder)
+ * |
+ * | start(initiator) /---------\
+ * | | | |
+ * v v v |
+ * KEY_MAYBE_V1 -> KEY -> GARB_GARBTERM -> GARBAUTH -> VERSION -> APP -> APP_READY
+ * |
+ * \-------> V1
+ */
+ enum class RecvState : uint8_t {
+ /** (Responder only) either v2 public key or v1 header.
+ *
+ * This is the initial state for responders, before data has been received to distinguish
+ * v1 from v2 connections. When that happens, the state becomes either KEY (for v2) or V1
+ * (for v1). */
+ KEY_MAYBE_V1,
+
+ /** Public key.
+ *
+ * This is the initial state for initiators, during which the other side's public key is
+ * received. When that information arrives, the ciphers get initialized and the state
+ * becomes GARB_GARBTERM. */
+ KEY,
+
+ /** Garbage and garbage terminator.
+ *
+ * Whenever a byte is received, the last 16 bytes are compared with the expected garbage
+ * terminator. When that happens, the state becomes GARBAUTH. If no matching terminator is
+ * received in 4111 bytes (4095 for the maximum garbage length, and 16 bytes for the
+ * terminator), the connection aborts. */
+ GARB_GARBTERM,
+
+ /** Garbage authentication packet.
+ *
+ * A packet is received, and decrypted/verified with AAD set to the garbage received during
+ * the GARB_GARBTERM state. If that succeeds, the state becomes VERSION. If it fails the
+ * connection aborts. */
+ GARBAUTH,
+
+ /** Version packet.
+ *
+ * A packet is received, and decrypted/verified. If that succeeds, the state becomes APP,
+ * and the decrypted contents is interpreted as version negotiation (currently, that means
+ * ignoring it, but it can be used for negotiating future extensions). If it fails, the
+ * connection aborts. */
+ VERSION,
+
+ /** Application packet.
+ *
+ * A packet is received, and decrypted/verified. If that succeeds, the state becomes
+ * APP_READY and the decrypted contents is kept in m_recv_decode_buffer until it is
+ * retrieved as a message by GetMessage(). */
+ APP,
+
+ /** Nothing (an application packet is available for GetMessage()).
+ *
+ * Nothing can be received in this state. When the message is retrieved by GetMessage,
+ * the state becomes APP again. */
+ APP_READY,
+
+ /** Nothing (this transport is using v1 fallback).
+ *
+ * All receive operations are redirected to m_v1_fallback. */
+ V1,
+ };
+
+ /** State type that controls the sender side.
+ *
+ * Diagram:
+ *
+ * start(responder)
+ * |
+ * | start(initiator)
+ * | |
+ * v v
+ * MAYBE_V1 -> AWAITING_KEY -> READY
+ * |
+ * \-----> V1
+ */
+ enum class SendState : uint8_t {
+ /** (Responder only) Not sending until v1 or v2 is detected.
+ *
+ * This is the initial state for responders. The send buffer contains the public key to
+ * send, but nothing is sent in this state yet. When the receiver determines whether this
+ * is a V1 or V2 connection, the sender state becomes AWAITING_KEY (for v2) or V1 (for v1).
+ */
+ MAYBE_V1,
+
+ /** Waiting for the other side's public key.
+ *
+ * This is the initial state for initiators. The public key is sent out. When the receiver
+ * receives the other side's public key and transitions to GARB_GARBTERM, the sender state
+ * becomes READY. */
+ AWAITING_KEY,
+
+ /** Normal sending state.
+ *
+ * In this state, the ciphers are initialized, so packets can be sent. When this state is
+ * entered, the garbage, garbage terminator, garbage authentication packet, and version
+ * packet are appended to the send buffer (in addition to the key which may still be
+ * there). In this state a message can be provided if the send buffer is empty. */
+ READY,
+
+ /** This transport is using v1 fallback.
+ *
+ * All send operations are redirected to m_v1_fallback. */
+ V1,
+ };
+
+ /** Cipher state. */
+ BIP324Cipher m_cipher;
+ /** Whether we are the initiator side. */
+ const bool m_initiating;
+ /** NodeId (for debug logging). */
+ const NodeId m_nodeid;
+ /** Encapsulate a V1Transport to fall back to. */
+ V1Transport m_v1_fallback;
+
+ /** Lock for receiver-side fields. */
+ mutable Mutex m_recv_mutex ACQUIRED_BEFORE(m_send_mutex);
+ /** In {GARBAUTH, VERSION, APP}, the decrypted packet length, if m_recv_buffer.size() >=
+ * BIP324Cipher::LENGTH_LEN. Unspecified otherwise. */
+ uint32_t m_recv_len GUARDED_BY(m_recv_mutex) {0};
+ /** Receive buffer; meaning is determined by m_recv_state. */
+ std::vector<uint8_t> m_recv_buffer GUARDED_BY(m_recv_mutex);
+ /** During GARBAUTH, the garbage received during GARB_GARBTERM. */
+ std::vector<uint8_t> m_recv_garbage GUARDED_BY(m_recv_mutex);
+ /** Buffer to put decrypted contents in, for converting to CNetMessage. */
+ std::vector<uint8_t> m_recv_decode_buffer GUARDED_BY(m_recv_mutex);
+ /** Deserialization type. */
+ const int m_recv_type;
+ /** Deserialization version number. */
+ const int m_recv_version;
+ /** Current receiver state. */
+ RecvState m_recv_state GUARDED_BY(m_recv_mutex);
+
+ /** Lock for sending-side fields. If both sending and receiving fields are accessed,
+ * m_recv_mutex must be acquired before m_send_mutex. */
+ mutable Mutex m_send_mutex ACQUIRED_AFTER(m_recv_mutex);
+ /** The send buffer; meaning is determined by m_send_state. */
+ std::vector<uint8_t> m_send_buffer GUARDED_BY(m_send_mutex);
+ /** How many bytes from the send buffer have been sent so far. */
+ uint32_t m_send_pos GUARDED_BY(m_send_mutex) {0};
+ /** Type of the message being sent. */
+ std::string m_send_type GUARDED_BY(m_send_mutex);
+ /** Current sender state. */
+ SendState m_send_state GUARDED_BY(m_send_mutex);
+
+ /** Change the receive state. */
+ void SetReceiveState(RecvState recv_state) noexcept EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex);
+ /** Change the send state. */
+ void SetSendState(SendState send_state) noexcept EXCLUSIVE_LOCKS_REQUIRED(m_send_mutex);
+ /** Given a packet's contents, find the message type (if valid), and strip it from contents. */
+ static std::optional<std::string> GetMessageType(Span<const uint8_t>& contents) noexcept;
+ /** Determine how many received bytes can be processed in one go (not allowed in V1 state). */
+ size_t GetMaxBytesToProcess() noexcept EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex);
+ /** Process bytes in m_recv_buffer, while in KEY_MAYBE_V1 state. */
+ void ProcessReceivedMaybeV1Bytes() noexcept EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex, !m_send_mutex);
+ /** Process bytes in m_recv_buffer, while in KEY state. */
+ bool ProcessReceivedKeyBytes() noexcept EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex, !m_send_mutex);
+ /** Process bytes in m_recv_buffer, while in GARB_GARBTERM state. */
+ bool ProcessReceivedGarbageBytes() noexcept EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex);
+ /** Process bytes in m_recv_buffer, while in GARBAUTH/VERSION/APP state. */
+ bool ProcessReceivedPacketBytes() noexcept EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex);
+
+public:
+ static constexpr uint32_t MAX_GARBAGE_LEN = 4095;
+
+ /** Construct a V2 transport with securely generated random keys.
+ *
+ * @param[in] nodeid the node's NodeId (only for debug log output).
+ * @param[in] initiating whether we are the initiator side.
+ * @param[in] type_in the serialization type of returned CNetMessages.
+ * @param[in] version_in the serialization version of returned CNetMessages.
+ */
+ V2Transport(NodeId nodeid, bool initiating, int type_in, int version_in) noexcept;
+
+ /** Construct a V2 transport with specified keys and garbage (test use only). */
+ V2Transport(NodeId nodeid, bool initiating, int type_in, int version_in, const CKey& key, Span<const std::byte> ent32, Span<const uint8_t> garbage) noexcept;
+
+ // Receive side functions.
+ bool ReceivedMessageComplete() const noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex);
+ bool ReceivedBytes(Span<const uint8_t>& msg_bytes) noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex, !m_send_mutex);
+ CNetMessage GetReceivedMessage(std::chrono::microseconds time, bool& reject_message) noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex);
+
+ // Send side functions.
+ bool SetMessageToSend(CSerializedNetMsg& msg) noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex);
+ BytesToSend GetBytesToSend(bool have_next_message) const noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex);
void MarkBytesSent(size_t bytes_sent) noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex);
size_t GetSendMemoryUsage() const noexcept override EXCLUSIVE_LOCKS_REQUIRED(!m_send_mutex);
};
diff --git a/src/net_processing.cpp b/src/net_processing.cpp
index c5a22f258a..6b415b3a1e 100644
--- a/src/net_processing.cpp
+++ b/src/net_processing.cpp
@@ -1415,8 +1415,8 @@ void PeerManagerImpl::PushNodeVersion(CNode& pnode, const Peer& peer)
const bool tx_relay{!RejectIncomingTxs(pnode)};
m_connman.PushMessage(&pnode, CNetMsgMaker(INIT_PROTO_VERSION).Make(NetMsgType::VERSION, PROTOCOL_VERSION, my_services, nTime,
- your_services, addr_you, // Together the pre-version-31402 serialization of CAddress "addrYou" (without nTime)
- my_services, CService(), // Together the pre-version-31402 serialization of CAddress "addrMe" (without nTime)
+ your_services, WithParams(CNetAddr::V1, addr_you), // Together the pre-version-31402 serialization of CAddress "addrYou" (without nTime)
+ my_services, WithParams(CNetAddr::V1, CService{}), // Together the pre-version-31402 serialization of CAddress "addrMe" (without nTime)
nonce, strSubVersion, nNodeStartingHeight, tx_relay));
if (fLogIPs) {
@@ -3293,7 +3293,7 @@ void PeerManagerImpl::ProcessMessage(CNode& pfrom, const std::string& msg_type,
nTime = 0;
}
vRecv.ignore(8); // Ignore the addrMe service bits sent by the peer
- vRecv >> addrMe;
+ vRecv >> WithParams(CNetAddr::V1, addrMe);
if (!pfrom.IsInboundConn())
{
m_addrman.SetServices(pfrom.addr, nServices);
@@ -3672,17 +3672,17 @@ void PeerManagerImpl::ProcessMessage(CNode& pfrom, const std::string& msg_type,
}
if (msg_type == NetMsgType::ADDR || msg_type == NetMsgType::ADDRV2) {
- int stream_version = vRecv.GetVersion();
- if (msg_type == NetMsgType::ADDRV2) {
- // Add ADDRV2_FORMAT to the version so that the CNetAddr and CAddress
+ const auto ser_params{
+ msg_type == NetMsgType::ADDRV2 ?
+ // Set V2 param so that the CNetAddr and CAddress
// unserialize methods know that an address in v2 format is coming.
- stream_version |= ADDRV2_FORMAT;
- }
+ CAddress::V2_NETWORK :
+ CAddress::V1_NETWORK,
+ };
- OverrideStream<CDataStream> s(&vRecv, vRecv.GetType(), stream_version);
std::vector<CAddress> vAddr;
- s >> vAddr;
+ vRecv >> WithParams(ser_params, vAddr);
if (!SetupAddressRelay(pfrom, *peer)) {
LogPrint(BCLog::NET, "ignoring %s message from %s peer=%d\n", msg_type, pfrom.ConnectionTypeAsString(), pfrom.GetId());
@@ -5289,15 +5289,15 @@ void PeerManagerImpl::MaybeSendAddr(CNode& node, Peer& peer, std::chrono::micros
if (peer.m_addrs_to_send.empty()) return;
const char* msg_type;
- int make_flags;
+ CNetAddr::Encoding ser_enc;
if (peer.m_wants_addrv2) {
msg_type = NetMsgType::ADDRV2;
- make_flags = ADDRV2_FORMAT;
+ ser_enc = CNetAddr::Encoding::V2;
} else {
msg_type = NetMsgType::ADDR;
- make_flags = 0;
+ ser_enc = CNetAddr::Encoding::V1;
}
- m_connman.PushMessage(&node, CNetMsgMaker(node.GetCommonVersion()).Make(make_flags, msg_type, peer.m_addrs_to_send));
+ m_connman.PushMessage(&node, CNetMsgMaker(node.GetCommonVersion()).Make(msg_type, WithParams(CAddress::SerParams{{ser_enc}, CAddress::Format::Network}, peer.m_addrs_to_send)));
peer.m_addrs_to_send.clear();
// we only send the big addr message once
diff --git a/src/netaddress.h b/src/netaddress.h
index 7cba6c00d0..a0944c886f 100644
--- a/src/netaddress.h
+++ b/src/netaddress.h
@@ -25,14 +25,6 @@
#include <vector>
/**
- * A flag that is ORed into the protocol version to designate that addresses
- * should be serialized in (unserialized from) v2 format (BIP155).
- * Make sure that this does not collide with any of the values in `version.h`
- * or with `SERIALIZE_TRANSACTION_NO_WITNESS`.
- */
-static constexpr int ADDRV2_FORMAT = 0x20000000;
-
-/**
* A network type.
* @note An address may belong to more than one network, for example `10.0.0.1`
* belongs to both `NET_UNROUTABLE` and `NET_IPV4`.
@@ -220,13 +212,23 @@ public:
return IsIPv4() || IsIPv6() || IsTor() || IsI2P() || IsCJDNS();
}
+ enum class Encoding {
+ V1,
+ V2, //!< BIP155 encoding
+ };
+ struct SerParams {
+ const Encoding enc;
+ };
+ static constexpr SerParams V1{Encoding::V1};
+ static constexpr SerParams V2{Encoding::V2};
+
/**
* Serialize to a stream.
*/
template <typename Stream>
void Serialize(Stream& s) const
{
- if (s.GetVersion() & ADDRV2_FORMAT) {
+ if (s.GetParams().enc == Encoding::V2) {
SerializeV2Stream(s);
} else {
SerializeV1Stream(s);
@@ -239,7 +241,7 @@ public:
template <typename Stream>
void Unserialize(Stream& s)
{
- if (s.GetVersion() & ADDRV2_FORMAT) {
+ if (s.GetParams().enc == Encoding::V2) {
UnserializeV2Stream(s);
} else {
UnserializeV1Stream(s);
@@ -540,8 +542,7 @@ public:
SERIALIZE_METHODS(CService, obj)
{
- READWRITEAS(CNetAddr, obj);
- READWRITE(Using<BigEndianFormatter<2>>(obj.port));
+ READWRITE(AsBase<CNetAddr>(obj), Using<BigEndianFormatter<2>>(obj.port));
}
friend class CServiceHash;
diff --git a/src/primitives/block.h b/src/primitives/block.h
index bd11279a6e..99accfc7dd 100644
--- a/src/primitives/block.h
+++ b/src/primitives/block.h
@@ -87,8 +87,7 @@ public:
SERIALIZE_METHODS(CBlock, obj)
{
- READWRITEAS(CBlockHeader, obj);
- READWRITE(obj.vtx);
+ READWRITE(AsBase<CBlockHeader>(obj), obj.vtx);
}
void SetNull()
@@ -119,6 +118,15 @@ public:
*/
struct CBlockLocator
{
+ /** Historically CBlockLocator's version field has been written to network
+ * streams as the negotiated protocol version and to disk streams as the
+ * client version, but the value has never been used.
+ *
+ * Hard-code to the highest protocol version ever written to a network stream.
+ * SerParams can be used if the field requires any meaning in the future,
+ **/
+ static constexpr int DUMMY_VERSION = 70016;
+
std::vector<uint256> vHave;
CBlockLocator() {}
@@ -127,9 +135,8 @@ struct CBlockLocator
SERIALIZE_METHODS(CBlockLocator, obj)
{
- int nVersion = s.GetVersion();
- if (!(s.GetType() & SER_GETHASH))
- READWRITE(nVersion);
+ int nVersion = DUMMY_VERSION;
+ READWRITE(nVersion);
READWRITE(obj.vHave);
}
diff --git a/src/protocol.h b/src/protocol.h
index ac4545c311..a7ca0c6f3e 100644
--- a/src/protocol.h
+++ b/src/protocol.h
@@ -390,35 +390,43 @@ public:
CAddress(CService ipIn, ServiceFlags nServicesIn) : CService{ipIn}, nServices{nServicesIn} {};
CAddress(CService ipIn, ServiceFlags nServicesIn, NodeSeconds time) : CService{ipIn}, nTime{time}, nServices{nServicesIn} {};
- SERIALIZE_METHODS(CAddress, obj)
+ enum class Format {
+ Disk,
+ Network,
+ };
+ struct SerParams : CNetAddr::SerParams {
+ const Format fmt;
+ };
+ static constexpr SerParams V1_NETWORK{{CNetAddr::Encoding::V1}, Format::Network};
+ static constexpr SerParams V2_NETWORK{{CNetAddr::Encoding::V2}, Format::Network};
+ static constexpr SerParams V1_DISK{{CNetAddr::Encoding::V1}, Format::Disk};
+ static constexpr SerParams V2_DISK{{CNetAddr::Encoding::V2}, Format::Disk};
+
+ SERIALIZE_METHODS_PARAMS(CAddress, obj, SerParams, params)
{
- // CAddress has a distinct network serialization and a disk serialization, but it should never
- // be hashed (except through CHashWriter in addrdb.cpp, which sets SER_DISK), and it's
- // ambiguous what that would mean. Make sure no code relying on that is introduced:
- assert(!(s.GetType() & SER_GETHASH));
bool use_v2;
- if (s.GetType() & SER_DISK) {
+ if (params.fmt == Format::Disk) {
// In the disk serialization format, the encoding (v1 or v2) is determined by a flag version
// that's part of the serialization itself. ADDRV2_FORMAT in the stream version only determines
// whether V2 is chosen/permitted at all.
uint32_t stored_format_version = DISK_VERSION_INIT;
- if (s.GetVersion() & ADDRV2_FORMAT) stored_format_version |= DISK_VERSION_ADDRV2;
+ if (params.enc == Encoding::V2) stored_format_version |= DISK_VERSION_ADDRV2;
READWRITE(stored_format_version);
stored_format_version &= ~DISK_VERSION_IGNORE_MASK; // ignore low bits
if (stored_format_version == 0) {
use_v2 = false;
- } else if (stored_format_version == DISK_VERSION_ADDRV2 && (s.GetVersion() & ADDRV2_FORMAT)) {
- // Only support v2 deserialization if ADDRV2_FORMAT is set.
+ } else if (stored_format_version == DISK_VERSION_ADDRV2 && params.enc == Encoding::V2) {
+ // Only support v2 deserialization if V2 is set.
use_v2 = true;
} else {
throw std::ios_base::failure("Unsupported CAddress disk format version");
}
} else {
+ assert(params.fmt == Format::Network);
// In the network serialization format, the encoding (v1 or v2) is determined directly by
- // the value of ADDRV2_FORMAT in the stream version, as no explicitly encoded version
+ // the value of enc in the stream params, as no explicitly encoded version
// exists in the stream.
- assert(s.GetType() & SER_NETWORK);
- use_v2 = s.GetVersion() & ADDRV2_FORMAT;
+ use_v2 = params.enc == Encoding::V2;
}
READWRITE(Using<LossyChronoFormatter<uint32_t>>(obj.nTime));
@@ -432,8 +440,8 @@ public:
READWRITE(Using<CustomUintFormatter<8>>(obj.nServices));
}
// Invoke V1/V2 serializer for CService parent object.
- OverrideStream<Stream> os(&s, s.GetType(), use_v2 ? ADDRV2_FORMAT : 0);
- SerReadWriteMany(os, ser_action, ReadWriteAsHelper<CService>(obj));
+ const auto ser_params{use_v2 ? CNetAddr::V2 : CNetAddr::V1};
+ READWRITE(WithParams(ser_params, AsBase<CService>(obj)));
}
//! Always included in serialization. The behavior is unspecified if the value is not representable as uint32_t.
diff --git a/src/pubkey.cpp b/src/pubkey.cpp
index 4866feed67..05808e4c22 100644
--- a/src/pubkey.cpp
+++ b/src/pubkey.cpp
@@ -336,6 +336,12 @@ bool CPubKey::Derive(CPubKey& pubkeyChild, ChainCode &ccChild, unsigned int nChi
return true;
}
+EllSwiftPubKey::EllSwiftPubKey(Span<const std::byte> ellswift) noexcept
+{
+ assert(ellswift.size() == SIZE);
+ std::copy(ellswift.begin(), ellswift.end(), m_pubkey.begin());
+}
+
CPubKey EllSwiftPubKey::Decode() const
{
secp256k1_pubkey pubkey;
diff --git a/src/pubkey.h b/src/pubkey.h
index 00defa25a0..274779f9a4 100644
--- a/src/pubkey.h
+++ b/src/pubkey.h
@@ -303,8 +303,7 @@ public:
EllSwiftPubKey() noexcept = default;
/** Construct a new ellswift public key from a given serialization. */
- EllSwiftPubKey(const std::array<std::byte, SIZE>& ellswift) :
- m_pubkey(ellswift) {}
+ EllSwiftPubKey(Span<const std::byte> ellswift) noexcept;
/** Decode to normal compressed CPubKey (for debugging purposes). */
CPubKey Decode() const;
diff --git a/src/script/script.h b/src/script/script.h
index 902f756afc..c329a2afd6 100644
--- a/src/script/script.h
+++ b/src/script/script.h
@@ -434,7 +434,7 @@ public:
CScript(std::vector<unsigned char>::const_iterator pbegin, std::vector<unsigned char>::const_iterator pend) : CScriptBase(pbegin, pend) { }
CScript(const unsigned char* pbegin, const unsigned char* pend) : CScriptBase(pbegin, pend) { }
- SERIALIZE_METHODS(CScript, obj) { READWRITEAS(CScriptBase, obj); }
+ SERIALIZE_METHODS(CScript, obj) { READWRITE(AsBase<CScriptBase>(obj)); }
explicit CScript(int64_t b) { operator<<(b); }
explicit CScript(opcodetype b) { operator<<(b); }
diff --git a/src/serialize.h b/src/serialize.h
index 39f2c0f3ae..2d790190a0 100644
--- a/src/serialize.h
+++ b/src/serialize.h
@@ -6,6 +6,7 @@
#ifndef BITCOIN_SERIALIZE_H
#define BITCOIN_SERIALIZE_H
+#include <attributes.h>
#include <compat/endian.h>
#include <algorithm>
@@ -133,12 +134,40 @@ enum
SER_GETHASH = (1 << 2),
};
-//! Convert the reference base type to X, without changing constness or reference type.
-template<typename X> X& ReadWriteAsHelper(X& x) { return x; }
-template<typename X> const X& ReadWriteAsHelper(const X& x) { return x; }
+/**
+ * Convert any argument to a reference to X, maintaining constness.
+ *
+ * This can be used in serialization code to invoke a base class's
+ * serialization routines.
+ *
+ * Example use:
+ * class Base { ... };
+ * class Child : public Base {
+ * int m_data;
+ * public:
+ * SERIALIZE_METHODS(Child, obj) {
+ * READWRITE(AsBase<Base>(obj), obj.m_data);
+ * }
+ * };
+ *
+ * static_cast cannot easily be used here, as the type of Obj will be const Child&
+ * during serialization and Child& during deserialization. AsBase will convert to
+ * const Base& and Base& appropriately.
+ */
+template <class Out, class In>
+Out& AsBase(In& x)
+{
+ static_assert(std::is_base_of_v<Out, In>);
+ return x;
+}
+template <class Out, class In>
+const Out& AsBase(const In& x)
+{
+ static_assert(std::is_base_of_v<Out, In>);
+ return x;
+}
#define READWRITE(...) (::SerReadWriteMany(s, ser_action, __VA_ARGS__))
-#define READWRITEAS(type, obj) (::SerReadWriteMany(s, ser_action, ReadWriteAsHelper<type>(obj)))
#define SER_READ(obj, code) ::SerRead(s, ser_action, obj, [&](Stream& s, typename std::remove_const<Type>::type& obj) { code; })
#define SER_WRITE(obj, code) ::SerWrite(s, ser_action, obj, [&](Stream& s, const Type& obj) { code; })
@@ -160,11 +189,66 @@ template<typename X> const X& ReadWriteAsHelper(const X& x) { return x; }
*/
#define FORMATTER_METHODS(cls, obj) \
template<typename Stream> \
- static void Ser(Stream& s, const cls& obj) { SerializationOps(obj, s, CSerActionSerialize()); } \
+ static void Ser(Stream& s, const cls& obj) { SerializationOps(obj, s, ActionSerialize{}); } \
template<typename Stream> \
- static void Unser(Stream& s, cls& obj) { SerializationOps(obj, s, CSerActionUnserialize()); } \
+ static void Unser(Stream& s, cls& obj) { SerializationOps(obj, s, ActionUnserialize{}); } \
template<typename Stream, typename Type, typename Operation> \
- static inline void SerializationOps(Type& obj, Stream& s, Operation ser_action) \
+ static void SerializationOps(Type& obj, Stream& s, Operation ser_action)
+
+/**
+ * Variant of FORMATTER_METHODS that supports a declared parameter type.
+ *
+ * If a formatter has a declared parameter type, it must be invoked directly or
+ * indirectly with a parameter of that type. This permits making serialization
+ * depend on run-time context in a type-safe way.
+ *
+ * Example use:
+ * struct BarParameter { bool fancy; ... };
+ * struct Bar { ... };
+ * struct FooFormatter {
+ * FORMATTER_METHODS(Bar, obj, BarParameter, param) {
+ * if (param.fancy) {
+ * READWRITE(VARINT(obj.value));
+ * } else {
+ * READWRITE(obj.value);
+ * }
+ * }
+ * };
+ * which would then be invoked as
+ * READWRITE(WithParams(BarParameter{...}, Using<FooFormatter>(obj.foo)))
+ *
+ * WithParams(parameter, obj) can be invoked anywhere in the call stack; it is
+ * passed down recursively into all serialization code, until another
+ * WithParams overrides it.
+ *
+ * Parameters will be implicitly converted where appropriate. This means that
+ * "parent" serialization code can use a parameter that derives from, or is
+ * convertible to, a "child" formatter's parameter type.
+ *
+ * Compilation will fail in any context where serialization is invoked but
+ * no parameter of a type convertible to BarParameter is provided.
+ */
+#define FORMATTER_METHODS_PARAMS(cls, obj, paramcls, paramobj) \
+ template <typename Stream> \
+ static void Ser(Stream& s, const cls& obj) { SerializationOps(obj, s, ActionSerialize{}, s.GetParams()); } \
+ template <typename Stream> \
+ static void Unser(Stream& s, cls& obj) { SerializationOps(obj, s, ActionUnserialize{}, s.GetParams()); } \
+ template <typename Stream, typename Type, typename Operation> \
+ static void SerializationOps(Type& obj, Stream& s, Operation ser_action, const paramcls& paramobj)
+
+#define BASE_SERIALIZE_METHODS(cls) \
+ template <typename Stream> \
+ void Serialize(Stream& s) const \
+ { \
+ static_assert(std::is_same<const cls&, decltype(*this)>::value, "Serialize type mismatch"); \
+ Ser(s, *this); \
+ } \
+ template <typename Stream> \
+ void Unserialize(Stream& s) \
+ { \
+ static_assert(std::is_same<cls&, decltype(*this)>::value, "Unserialize type mismatch"); \
+ Unser(s, *this); \
+ }
/**
* Implement the Serialize and Unserialize methods by delegating to a single templated
@@ -173,21 +257,19 @@ template<typename X> const X& ReadWriteAsHelper(const X& x) { return x; }
* thus allows a single implementation that sees the object as const for serializing
* and non-const for deserializing, without casts.
*/
-#define SERIALIZE_METHODS(cls, obj) \
- template<typename Stream> \
- void Serialize(Stream& s) const \
- { \
- static_assert(std::is_same<const cls&, decltype(*this)>::value, "Serialize type mismatch"); \
- Ser(s, *this); \
- } \
- template<typename Stream> \
- void Unserialize(Stream& s) \
- { \
- static_assert(std::is_same<cls&, decltype(*this)>::value, "Unserialize type mismatch"); \
- Unser(s, *this); \
- } \
+#define SERIALIZE_METHODS(cls, obj) \
+ BASE_SERIALIZE_METHODS(cls) \
FORMATTER_METHODS(cls, obj)
+/**
+ * Variant of SERIALIZE_METHODS that supports a declared parameter type.
+ *
+ * See FORMATTER_METHODS_PARAMS for more information on parameters.
+ */
+#define SERIALIZE_METHODS_PARAMS(cls, obj, paramcls, paramobj) \
+ BASE_SERIALIZE_METHODS(cls) \
+ FORMATTER_METHODS_PARAMS(cls, obj, paramcls, paramobj)
+
// clang-format off
#ifndef CHAR_EQUALS_INT8
template <typename Stream> void Serialize(Stream&, char) = delete; // char serialization forbidden. Use uint8_t or int8_t
@@ -925,26 +1007,17 @@ void Unserialize(Stream& is, std::shared_ptr<const T>& p)
}
-
/**
- * Support for SERIALIZE_METHODS and READWRITE macro.
+ * Support for all macros providing or using the ser_action parameter of the SerializationOps method.
*/
-struct CSerActionSerialize
-{
+struct ActionSerialize {
constexpr bool ForRead() const { return false; }
};
-struct CSerActionUnserialize
-{
+struct ActionUnserialize {
constexpr bool ForRead() const { return true; }
};
-
-
-
-
-
-
/* ::GetSerializeSize implementations
*
* Computing the serialized size of objects is done through a special stream
@@ -1003,36 +1076,36 @@ inline void UnserializeMany(Stream& s, Args&&... args)
}
template<typename Stream, typename... Args>
-inline void SerReadWriteMany(Stream& s, CSerActionSerialize ser_action, const Args&... args)
+inline void SerReadWriteMany(Stream& s, ActionSerialize ser_action, const Args&... args)
{
::SerializeMany(s, args...);
}
template<typename Stream, typename... Args>
-inline void SerReadWriteMany(Stream& s, CSerActionUnserialize ser_action, Args&&... args)
+inline void SerReadWriteMany(Stream& s, ActionUnserialize ser_action, Args&&... args)
{
::UnserializeMany(s, args...);
}
template<typename Stream, typename Type, typename Fn>
-inline void SerRead(Stream& s, CSerActionSerialize ser_action, Type&&, Fn&&)
+inline void SerRead(Stream& s, ActionSerialize ser_action, Type&&, Fn&&)
{
}
template<typename Stream, typename Type, typename Fn>
-inline void SerRead(Stream& s, CSerActionUnserialize ser_action, Type&& obj, Fn&& fn)
+inline void SerRead(Stream& s, ActionUnserialize ser_action, Type&& obj, Fn&& fn)
{
fn(s, std::forward<Type>(obj));
}
template<typename Stream, typename Type, typename Fn>
-inline void SerWrite(Stream& s, CSerActionSerialize ser_action, Type&& obj, Fn&& fn)
+inline void SerWrite(Stream& s, ActionSerialize ser_action, Type&& obj, Fn&& fn)
{
fn(s, std::forward<Type>(obj));
}
template<typename Stream, typename Type, typename Fn>
-inline void SerWrite(Stream& s, CSerActionUnserialize ser_action, Type&&, Fn&&)
+inline void SerWrite(Stream& s, ActionUnserialize ser_action, Type&&, Fn&&)
{
}
@@ -1061,4 +1134,61 @@ size_t GetSerializeSizeMany(int nVersion, const T&... t)
return sc.size();
}
+/** Wrapper that overrides the GetParams() function of a stream (and hides GetVersion/GetType). */
+template <typename Params, typename SubStream>
+class ParamsStream
+{
+ const Params& m_params;
+ SubStream& m_substream; // private to avoid leaking version/type into serialization code that shouldn't see it
+
+public:
+ ParamsStream(const Params& params LIFETIMEBOUND, SubStream& substream LIFETIMEBOUND) : m_params{params}, m_substream{substream} {}
+ template <typename U> ParamsStream& operator<<(const U& obj) { ::Serialize(*this, obj); return *this; }
+ template <typename U> ParamsStream& operator>>(U&& obj) { ::Unserialize(*this, obj); return *this; }
+ void write(Span<const std::byte> src) { m_substream.write(src); }
+ void read(Span<std::byte> dst) { m_substream.read(dst); }
+ void ignore(size_t num) { m_substream.ignore(num); }
+ bool eof() const { return m_substream.eof(); }
+ size_t size() const { return m_substream.size(); }
+ const Params& GetParams() const { return m_params; }
+ int GetVersion() = delete; // Deprecated with Params usage
+ int GetType() = delete; // Deprecated with Params usage
+};
+
+/** Wrapper that serializes objects with the specified parameters. */
+template <typename Params, typename T>
+class ParamsWrapper
+{
+ static_assert(std::is_lvalue_reference<T>::value, "ParamsWrapper needs an lvalue reference type T");
+ const Params& m_params;
+ T m_object;
+
+public:
+ explicit ParamsWrapper(const Params& params, T obj) : m_params{params}, m_object{obj} {}
+
+ template <typename Stream>
+ void Serialize(Stream& s) const
+ {
+ ParamsStream ss{m_params, s};
+ ::Serialize(ss, m_object);
+ }
+ template <typename Stream>
+ void Unserialize(Stream& s)
+ {
+ ParamsStream ss{m_params, s};
+ ::Unserialize(ss, m_object);
+ }
+};
+
+/**
+ * Return a wrapper around t that (de)serializes it with specified parameter params.
+ *
+ * See FORMATTER_METHODS_PARAMS for more information on serialization parameters.
+ */
+template <typename Params, typename T>
+static auto WithParams(const Params& params, T&& t)
+{
+ return ParamsWrapper<Params, T&>{params, t};
+}
+
#endif // BITCOIN_SERIALIZE_H
diff --git a/src/test/addrman_tests.cpp b/src/test/addrman_tests.cpp
index 329b89554d..941018a820 100644
--- a/src/test/addrman_tests.cpp
+++ b/src/test/addrman_tests.cpp
@@ -697,7 +697,7 @@ BOOST_AUTO_TEST_CASE(addrman_serialization)
auto addrman_asmap1_dup = std::make_unique<AddrMan>(netgroupman, DETERMINISTIC, ratio);
auto addrman_noasmap = std::make_unique<AddrMan>(EMPTY_NETGROUPMAN, DETERMINISTIC, ratio);
- CDataStream stream(SER_NETWORK, PROTOCOL_VERSION);
+ DataStream stream{};
CAddress addr = CAddress(ResolveService("250.1.1.1"), NODE_NONE);
CNetAddr default_source;
@@ -757,7 +757,7 @@ BOOST_AUTO_TEST_CASE(remove_invalid)
// Confirm that invalid addresses are ignored in unserialization.
auto addrman = std::make_unique<AddrMan>(EMPTY_NETGROUPMAN, DETERMINISTIC, GetCheckRatio(m_node));
- CDataStream stream(SER_NETWORK, PROTOCOL_VERSION);
+ DataStream stream{};
const CAddress new1{ResolveService("5.5.5.5"), NODE_NONE};
const CAddress new2{ResolveService("6.6.6.6"), NODE_NONE};
@@ -940,9 +940,9 @@ BOOST_AUTO_TEST_CASE(addrman_evictionworks)
BOOST_CHECK(!addr_pos36.tried);
}
-static CDataStream AddrmanToStream(const AddrMan& addrman)
+static auto AddrmanToStream(const AddrMan& addrman)
{
- CDataStream ssPeersIn(SER_DISK, CLIENT_VERSION);
+ DataStream ssPeersIn{};
ssPeersIn << Params().MessageStart();
ssPeersIn << addrman;
return ssPeersIn;
@@ -972,7 +972,7 @@ BOOST_AUTO_TEST_CASE(load_addrman)
BOOST_CHECK(addrman.Size() == 3);
// Test that the de-serialization does not throw an exception.
- CDataStream ssPeers1 = AddrmanToStream(addrman);
+ auto ssPeers1{AddrmanToStream(addrman)};
bool exceptionThrown = false;
AddrMan addrman1{EMPTY_NETGROUPMAN, !DETERMINISTIC, GetCheckRatio(m_node)};
@@ -989,7 +989,7 @@ BOOST_AUTO_TEST_CASE(load_addrman)
BOOST_CHECK(exceptionThrown == false);
// Test that ReadFromStream creates an addrman with the correct number of addrs.
- CDataStream ssPeers2 = AddrmanToStream(addrman);
+ DataStream ssPeers2 = AddrmanToStream(addrman);
AddrMan addrman2{EMPTY_NETGROUPMAN, !DETERMINISTIC, GetCheckRatio(m_node)};
BOOST_CHECK(addrman2.Size() == 0);
@@ -998,9 +998,9 @@ BOOST_AUTO_TEST_CASE(load_addrman)
}
// Produce a corrupt peers.dat that claims 20 addrs when it only has one addr.
-static CDataStream MakeCorruptPeersDat()
+static auto MakeCorruptPeersDat()
{
- CDataStream s(SER_DISK, CLIENT_VERSION);
+ DataStream s{};
s << ::Params().MessageStart();
unsigned char nVersion = 1;
@@ -1019,7 +1019,7 @@ static CDataStream MakeCorruptPeersDat()
std::optional<CNetAddr> resolved{LookupHost("252.2.2.2", false)};
BOOST_REQUIRE(resolved.has_value());
AddrInfo info = AddrInfo(addr, resolved.value());
- s << info;
+ s << WithParams(CAddress::V1_DISK, info);
return s;
}
@@ -1027,7 +1027,7 @@ static CDataStream MakeCorruptPeersDat()
BOOST_AUTO_TEST_CASE(load_addrman_corrupted)
{
// Test that the de-serialization of corrupted peers.dat throws an exception.
- CDataStream ssPeers1 = MakeCorruptPeersDat();
+ auto ssPeers1{MakeCorruptPeersDat()};
bool exceptionThrown = false;
AddrMan addrman1{EMPTY_NETGROUPMAN, !DETERMINISTIC, GetCheckRatio(m_node)};
BOOST_CHECK(addrman1.Size() == 0);
@@ -1041,7 +1041,7 @@ BOOST_AUTO_TEST_CASE(load_addrman_corrupted)
BOOST_CHECK(exceptionThrown);
// Test that ReadFromStream fails if peers.dat is corrupt
- CDataStream ssPeers2 = MakeCorruptPeersDat();
+ auto ssPeers2{MakeCorruptPeersDat()};
AddrMan addrman2{EMPTY_NETGROUPMAN, !DETERMINISTIC, GetCheckRatio(m_node)};
BOOST_CHECK(addrman2.Size() == 0);
diff --git a/src/test/bip324_tests.cpp b/src/test/bip324_tests.cpp
index 04472611ec..1ed7e23bcf 100644
--- a/src/test/bip324_tests.cpp
+++ b/src/test/bip324_tests.cpp
@@ -38,14 +38,8 @@ void TestBIP324PacketVector(
{
// Convert input from hex to char/byte vectors/arrays.
const auto in_priv_ours = ParseHex(in_priv_ours_hex);
- const auto in_ellswift_ours_vec = ParseHex<std::byte>(in_ellswift_ours_hex);
- assert(in_ellswift_ours_vec.size() == 64);
- std::array<std::byte, 64> in_ellswift_ours;
- std::copy(in_ellswift_ours_vec.begin(), in_ellswift_ours_vec.end(), in_ellswift_ours.begin());
- const auto in_ellswift_theirs_vec = ParseHex<std::byte>(in_ellswift_theirs_hex);
- assert(in_ellswift_theirs_vec.size() == 64);
- std::array<std::byte, 64> in_ellswift_theirs;
- std::copy(in_ellswift_theirs_vec.begin(), in_ellswift_theirs_vec.end(), in_ellswift_theirs.begin());
+ const auto in_ellswift_ours = ParseHex<std::byte>(in_ellswift_ours_hex);
+ const auto in_ellswift_theirs = ParseHex<std::byte>(in_ellswift_theirs_hex);
const auto in_contents = ParseHex<std::byte>(in_contents_hex);
const auto in_aad = ParseHex<std::byte>(in_aad_hex);
const auto mid_send_garbage = ParseHex<std::byte>(mid_send_garbage_hex);
diff --git a/src/test/blockmanager_tests.cpp b/src/test/blockmanager_tests.cpp
index 553bb31ba1..13cb1cc314 100644
--- a/src/test/blockmanager_tests.cpp
+++ b/src/test/blockmanager_tests.cpp
@@ -3,6 +3,7 @@
// file COPYING or http://www.opensource.org/licenses/mit-license.php.
#include <chainparams.h>
+#include <clientversion.h>
#include <node/blockstorage.h>
#include <node/context.h>
#include <node/kernel_notifications.h>
diff --git a/src/test/denialofservice_tests.cpp b/src/test/denialofservice_tests.cpp
index 7f5d587cf6..8c1182b5e1 100644
--- a/src/test/denialofservice_tests.cpp
+++ b/src/test/denialofservice_tests.cpp
@@ -86,7 +86,7 @@ BOOST_AUTO_TEST_CASE(outbound_slow_chain_eviction)
{
LOCK(dummyNode1.cs_vSend);
- const auto& [to_send, _more, _msg_type] = dummyNode1.m_transport->GetBytesToSend();
+ const auto& [to_send, _more, _msg_type] = dummyNode1.m_transport->GetBytesToSend(false);
BOOST_CHECK(!to_send.empty());
}
connman.FlushSendBuffer(dummyNode1);
@@ -97,7 +97,7 @@ BOOST_AUTO_TEST_CASE(outbound_slow_chain_eviction)
BOOST_CHECK(peerman.SendMessages(&dummyNode1)); // should result in getheaders
{
LOCK(dummyNode1.cs_vSend);
- const auto& [to_send, _more, _msg_type] = dummyNode1.m_transport->GetBytesToSend();
+ const auto& [to_send, _more, _msg_type] = dummyNode1.m_transport->GetBytesToSend(false);
BOOST_CHECK(!to_send.empty());
}
// Wait 3 more minutes
diff --git a/src/test/fuzz/addrman.cpp b/src/test/fuzz/addrman.cpp
index 02df4590de..9611a872ec 100644
--- a/src/test/fuzz/addrman.cpp
+++ b/src/test/fuzz/addrman.cpp
@@ -49,7 +49,7 @@ void initialize_addrman()
FUZZ_TARGET(data_stream_addr_man, .init = initialize_addrman)
{
FuzzedDataProvider fuzzed_data_provider{buffer.data(), buffer.size()};
- CDataStream data_stream = ConsumeDataStream(fuzzed_data_provider);
+ DataStream data_stream = ConsumeDataStream(fuzzed_data_provider);
NetGroupManager netgroupman{ConsumeNetGroupManager(fuzzed_data_provider)};
AddrMan addr_man(netgroupman, /*deterministic=*/false, GetCheckRatio());
try {
@@ -78,12 +78,12 @@ CNetAddr RandAddr(FuzzedDataProvider& fuzzed_data_provider, FastRandomContext& f
net = 6;
}
- CDataStream s(SER_NETWORK, PROTOCOL_VERSION | ADDRV2_FORMAT);
+ DataStream s{};
s << net;
s << fast_random_context.randbytes(net_len_map.at(net));
- s >> addr;
+ s >> WithParams(CAddress::V2_NETWORK, addr);
}
// Return a dummy IPv4 5.5.5.5 if we generated an invalid address.
@@ -241,9 +241,7 @@ FUZZ_TARGET(addrman, .init = initialize_addrman)
auto addr_man_ptr = std::make_unique<AddrManDeterministic>(netgroupman, fuzzed_data_provider);
if (fuzzed_data_provider.ConsumeBool()) {
const std::vector<uint8_t> serialized_data{ConsumeRandomLengthByteVector(fuzzed_data_provider)};
- CDataStream ds(serialized_data, SER_DISK, INIT_PROTO_VERSION);
- const auto ser_version{fuzzed_data_provider.ConsumeIntegral<int32_t>()};
- ds.SetVersion(ser_version);
+ DataStream ds{serialized_data};
try {
ds >> *addr_man_ptr;
} catch (const std::ios_base::failure&) {
@@ -295,7 +293,7 @@ FUZZ_TARGET(addrman, .init = initialize_addrman)
in_new = fuzzed_data_provider.ConsumeBool();
}
(void)const_addr_man.Size(network, in_new);
- CDataStream data_stream(SER_NETWORK, PROTOCOL_VERSION);
+ DataStream data_stream{};
data_stream << const_addr_man;
}
@@ -309,10 +307,10 @@ FUZZ_TARGET(addrman_serdeser, .init = initialize_addrman)
AddrManDeterministic addr_man1{netgroupman, fuzzed_data_provider};
AddrManDeterministic addr_man2{netgroupman, fuzzed_data_provider};
- CDataStream data_stream(SER_NETWORK, PROTOCOL_VERSION);
+ DataStream data_stream{};
FillAddrman(addr_man1, fuzzed_data_provider);
data_stream << addr_man1;
data_stream >> addr_man2;
assert(addr_man1 == addr_man2);
-} \ No newline at end of file
+}
diff --git a/src/test/fuzz/bip324.cpp b/src/test/fuzz/bip324.cpp
index 98ac10e364..e5ed9bfd52 100644
--- a/src/test/fuzz/bip324.cpp
+++ b/src/test/fuzz/bip324.cpp
@@ -30,19 +30,13 @@ FUZZ_TARGET(bip324_cipher_roundtrip, .init=Initialize)
// Load keys from fuzzer.
FuzzedDataProvider provider(buffer.data(), buffer.size());
// Initiator key
- auto init_key_data = provider.ConsumeBytes<unsigned char>(32);
- init_key_data.resize(32);
- CKey init_key;
- init_key.Set(init_key_data.begin(), init_key_data.end(), true);
+ CKey init_key = ConsumePrivateKey(provider, /*compressed=*/true);
if (!init_key.IsValid()) return;
// Initiator entropy
auto init_ent = provider.ConsumeBytes<std::byte>(32);
init_ent.resize(32);
// Responder key
- auto resp_key_data = provider.ConsumeBytes<unsigned char>(32);
- resp_key_data.resize(32);
- CKey resp_key;
- resp_key.Set(resp_key_data.begin(), resp_key_data.end(), true);
+ CKey resp_key = ConsumePrivateKey(provider, /*compressed=*/true);
if (!resp_key.IsValid()) return;
// Responder entropy
auto resp_ent = provider.ConsumeBytes<std::byte>(32);
diff --git a/src/test/fuzz/deserialize.cpp b/src/test/fuzz/deserialize.cpp
index 09402233bd..100a6b4ee4 100644
--- a/src/test/fuzz/deserialize.cpp
+++ b/src/test/fuzz/deserialize.cpp
@@ -24,6 +24,8 @@
#include <pubkey.h>
#include <script/keyorigin.h>
#include <streams.h>
+#include <test/fuzz/fuzz.h>
+#include <test/fuzz/util.h>
#include <test/util/setup_common.h>
#include <undo.h>
#include <version.h>
@@ -34,8 +36,6 @@
#include <stdint.h>
#include <unistd.h>
-#include <test/fuzz/fuzz.h>
-
using node::SnapshotMetadata;
namespace {
@@ -62,6 +62,34 @@ namespace {
struct invalid_fuzzing_input_exception : public std::exception {
};
+template <typename T, typename P>
+DataStream Serialize(const T& obj, const P& params)
+{
+ DataStream ds{};
+ ds << WithParams(params, obj);
+ return ds;
+}
+
+template <typename T, typename P>
+T Deserialize(DataStream&& ds, const P& params)
+{
+ T obj;
+ ds >> WithParams(params, obj);
+ return obj;
+}
+
+template <typename T, typename P>
+void DeserializeFromFuzzingInput(FuzzBufferType buffer, T&& obj, const P& params)
+{
+ DataStream ds{buffer};
+ try {
+ ds >> WithParams(params, obj);
+ } catch (const std::ios_base::failure&) {
+ throw invalid_fuzzing_input_exception();
+ }
+ assert(buffer.empty() || !Serialize(obj, params).empty());
+}
+
template <typename T>
CDataStream Serialize(const T& obj, const int version = INIT_PROTO_VERSION, const int ser_type = SER_NETWORK)
{
@@ -79,7 +107,7 @@ T Deserialize(CDataStream ds)
}
template <typename T>
-void DeserializeFromFuzzingInput(FuzzBufferType buffer, T& obj, const std::optional<int> protocol_version = std::nullopt, const int ser_type = SER_NETWORK)
+void DeserializeFromFuzzingInput(FuzzBufferType buffer, T&& obj, const std::optional<int> protocol_version = std::nullopt, const int ser_type = SER_NETWORK)
{
CDataStream ds(buffer, ser_type, INIT_PROTO_VERSION);
if (protocol_version) {
@@ -101,6 +129,11 @@ void DeserializeFromFuzzingInput(FuzzBufferType buffer, T& obj, const std::optio
assert(buffer.empty() || !Serialize(obj).empty());
}
+template <typename T, typename P>
+void AssertEqualAfterSerializeDeserialize(const T& obj, const P& params)
+{
+ assert(Deserialize<T>(Serialize(obj, params), params) == obj);
+}
template <typename T>
void AssertEqualAfterSerializeDeserialize(const T& obj, const int version = INIT_PROTO_VERSION, const int ser_type = SER_NETWORK)
{
@@ -113,10 +146,11 @@ FUZZ_TARGET_DESERIALIZE(block_filter_deserialize, {
BlockFilter block_filter;
DeserializeFromFuzzingInput(buffer, block_filter);
})
-FUZZ_TARGET_DESERIALIZE(addr_info_deserialize, {
- AddrInfo addr_info;
- DeserializeFromFuzzingInput(buffer, addr_info);
-})
+FUZZ_TARGET(addr_info_deserialize, .init = initialize_deserialize)
+{
+ FuzzedDataProvider fdp{buffer.data(), buffer.size()};
+ (void)ConsumeDeserializable<AddrInfo>(fdp, ConsumeDeserializationParams<CAddress::SerParams>(fdp));
+}
FUZZ_TARGET_DESERIALIZE(block_file_info_deserialize, {
CBlockFileInfo block_file_info;
DeserializeFromFuzzingInput(buffer, block_file_info);
@@ -197,13 +231,6 @@ FUZZ_TARGET_DESERIALIZE(blockmerkleroot, {
bool mutated;
BlockMerkleRoot(block, &mutated);
})
-FUZZ_TARGET_DESERIALIZE(addrman_deserialize, {
- NetGroupManager netgroupman{std::vector<bool>()};
- AddrMan am(netgroupman,
- /*deterministic=*/false,
- g_setup->m_node.args->GetIntArg("-checkaddrman", 0));
- DeserializeFromFuzzingInput(buffer, am);
-})
FUZZ_TARGET_DESERIALIZE(blockheader_deserialize, {
CBlockHeader bh;
DeserializeFromFuzzingInput(buffer, bh);
@@ -220,66 +247,62 @@ FUZZ_TARGET_DESERIALIZE(coins_deserialize, {
Coin coin;
DeserializeFromFuzzingInput(buffer, coin);
})
-FUZZ_TARGET_DESERIALIZE(netaddr_deserialize, {
- CNetAddr na;
- DeserializeFromFuzzingInput(buffer, na);
+FUZZ_TARGET(netaddr_deserialize, .init = initialize_deserialize)
+{
+ FuzzedDataProvider fdp{buffer.data(), buffer.size()};
+ const auto maybe_na{ConsumeDeserializable<CNetAddr>(fdp, ConsumeDeserializationParams<CNetAddr::SerParams>(fdp))};
+ if (!maybe_na) return;
+ const CNetAddr& na{*maybe_na};
if (na.IsAddrV1Compatible()) {
- AssertEqualAfterSerializeDeserialize(na);
+ AssertEqualAfterSerializeDeserialize(na, ConsumeDeserializationParams<CNetAddr::SerParams>(fdp));
}
- AssertEqualAfterSerializeDeserialize(na, INIT_PROTO_VERSION | ADDRV2_FORMAT);
-})
-FUZZ_TARGET_DESERIALIZE(service_deserialize, {
- CService s;
- DeserializeFromFuzzingInput(buffer, s);
+ AssertEqualAfterSerializeDeserialize(na, CNetAddr::V2);
+}
+FUZZ_TARGET(service_deserialize, .init = initialize_deserialize)
+{
+ FuzzedDataProvider fdp{buffer.data(), buffer.size()};
+ const auto ser_params{ConsumeDeserializationParams<CNetAddr::SerParams>(fdp)};
+ const auto maybe_s{ConsumeDeserializable<CService>(fdp, ser_params)};
+ if (!maybe_s) return;
+ const CService& s{*maybe_s};
if (s.IsAddrV1Compatible()) {
- AssertEqualAfterSerializeDeserialize(s);
+ AssertEqualAfterSerializeDeserialize(s, ConsumeDeserializationParams<CNetAddr::SerParams>(fdp));
}
- AssertEqualAfterSerializeDeserialize(s, INIT_PROTO_VERSION | ADDRV2_FORMAT);
- CService s1;
- DeserializeFromFuzzingInput(buffer, s1, INIT_PROTO_VERSION);
- AssertEqualAfterSerializeDeserialize(s1, INIT_PROTO_VERSION);
- assert(s1.IsAddrV1Compatible());
- CService s2;
- DeserializeFromFuzzingInput(buffer, s2, INIT_PROTO_VERSION | ADDRV2_FORMAT);
- AssertEqualAfterSerializeDeserialize(s2, INIT_PROTO_VERSION | ADDRV2_FORMAT);
-})
+ AssertEqualAfterSerializeDeserialize(s, CNetAddr::V2);
+ if (ser_params.enc == CNetAddr::Encoding::V1) {
+ assert(s.IsAddrV1Compatible());
+ }
+}
FUZZ_TARGET_DESERIALIZE(messageheader_deserialize, {
CMessageHeader mh;
DeserializeFromFuzzingInput(buffer, mh);
(void)mh.IsCommandValid();
})
-FUZZ_TARGET_DESERIALIZE(address_deserialize_v1_notime, {
- CAddress a;
- DeserializeFromFuzzingInput(buffer, a, INIT_PROTO_VERSION);
- // A CAddress without nTime (as is expected under INIT_PROTO_VERSION) will roundtrip
- // in all 5 formats (with/without nTime, v1/v2, network/disk)
- AssertEqualAfterSerializeDeserialize(a, INIT_PROTO_VERSION);
- AssertEqualAfterSerializeDeserialize(a, PROTOCOL_VERSION);
- AssertEqualAfterSerializeDeserialize(a, 0, SER_DISK);
- AssertEqualAfterSerializeDeserialize(a, PROTOCOL_VERSION | ADDRV2_FORMAT);
- AssertEqualAfterSerializeDeserialize(a, ADDRV2_FORMAT, SER_DISK);
-})
-FUZZ_TARGET_DESERIALIZE(address_deserialize_v1_withtime, {
- CAddress a;
- DeserializeFromFuzzingInput(buffer, a, PROTOCOL_VERSION);
- // A CAddress in V1 mode will roundtrip in all 4 formats that have nTime.
- AssertEqualAfterSerializeDeserialize(a, PROTOCOL_VERSION);
- AssertEqualAfterSerializeDeserialize(a, 0, SER_DISK);
- AssertEqualAfterSerializeDeserialize(a, PROTOCOL_VERSION | ADDRV2_FORMAT);
- AssertEqualAfterSerializeDeserialize(a, ADDRV2_FORMAT, SER_DISK);
-})
-FUZZ_TARGET_DESERIALIZE(address_deserialize_v2, {
- CAddress a;
- DeserializeFromFuzzingInput(buffer, a, PROTOCOL_VERSION | ADDRV2_FORMAT);
- // A CAddress in V2 mode will roundtrip in both V2 formats, and also in the V1 formats
- // with time if it's V1 compatible.
- if (a.IsAddrV1Compatible()) {
- AssertEqualAfterSerializeDeserialize(a, PROTOCOL_VERSION);
- AssertEqualAfterSerializeDeserialize(a, 0, SER_DISK);
+FUZZ_TARGET(address_deserialize, .init = initialize_deserialize)
+{
+ FuzzedDataProvider fdp{buffer.data(), buffer.size()};
+ const auto ser_enc{ConsumeDeserializationParams<CNetAddr::SerParams>(fdp)};
+ const auto maybe_a{ConsumeDeserializable<CAddress>(fdp, CAddress::SerParams{{ser_enc}, CAddress::Format::Network})};
+ if (!maybe_a) return;
+ const CAddress& a{*maybe_a};
+ // A CAddress in V1 mode will roundtrip
+ // in all 4 formats (v1/v2, network/disk)
+ if (ser_enc.enc == CNetAddr::Encoding::V1) {
+ AssertEqualAfterSerializeDeserialize(a, CAddress::V1_NETWORK);
+ AssertEqualAfterSerializeDeserialize(a, CAddress::V1_DISK);
+ AssertEqualAfterSerializeDeserialize(a, CAddress::V2_NETWORK);
+ AssertEqualAfterSerializeDeserialize(a, CAddress::V2_DISK);
+ } else {
+ // A CAddress in V2 mode will roundtrip in both V2 formats, and also in the V1 formats
+ // if it's V1 compatible.
+ if (a.IsAddrV1Compatible()) {
+ AssertEqualAfterSerializeDeserialize(a, CAddress::V1_DISK);
+ AssertEqualAfterSerializeDeserialize(a, CAddress::V1_NETWORK);
+ }
+ AssertEqualAfterSerializeDeserialize(a, CAddress::V2_NETWORK);
+ AssertEqualAfterSerializeDeserialize(a, CAddress::V2_DISK);
}
- AssertEqualAfterSerializeDeserialize(a, PROTOCOL_VERSION | ADDRV2_FORMAT);
- AssertEqualAfterSerializeDeserialize(a, ADDRV2_FORMAT, SER_DISK);
-})
+}
FUZZ_TARGET_DESERIALIZE(inv_deserialize, {
CInv i;
DeserializeFromFuzzingInput(buffer, i);
diff --git a/src/test/fuzz/key.cpp b/src/test/fuzz/key.cpp
index a5a579d982..60f4081432 100644
--- a/src/test/fuzz/key.cpp
+++ b/src/test/fuzz/key.cpp
@@ -17,6 +17,7 @@
#include <streams.h>
#include <test/fuzz/FuzzedDataProvider.h>
#include <test/fuzz/fuzz.h>
+#include <test/fuzz/util.h>
#include <util/chaintype.h>
#include <util/strencodings.h>
@@ -312,10 +313,7 @@ FUZZ_TARGET(ellswift_roundtrip, .init = initialize_key)
{
FuzzedDataProvider fdp{buffer.data(), buffer.size()};
- auto key_bytes = fdp.ConsumeBytes<uint8_t>(32);
- key_bytes.resize(32);
- CKey key;
- key.Set(key_bytes.begin(), key_bytes.end(), true);
+ CKey key = ConsumePrivateKey(fdp, /*compressed=*/true);
if (!key.IsValid()) return;
auto ent32 = fdp.ConsumeBytes<std::byte>(32);
@@ -332,17 +330,11 @@ FUZZ_TARGET(bip324_ecdh, .init = initialize_key)
FuzzedDataProvider fdp{buffer.data(), buffer.size()};
// We generate private key, k1.
- auto rnd32 = fdp.ConsumeBytes<uint8_t>(32);
- rnd32.resize(32);
- CKey k1;
- k1.Set(rnd32.begin(), rnd32.end(), true);
+ CKey k1 = ConsumePrivateKey(fdp, /*compressed=*/true);
if (!k1.IsValid()) return;
// They generate private key, k2.
- rnd32 = fdp.ConsumeBytes<uint8_t>(32);
- rnd32.resize(32);
- CKey k2;
- k2.Set(rnd32.begin(), rnd32.end(), true);
+ CKey k2 = ConsumePrivateKey(fdp, /*compressed=*/true);
if (!k2.IsValid()) return;
// We construct an ellswift encoding for our key, k1_ellswift.
diff --git a/src/test/fuzz/message.cpp b/src/test/fuzz/message.cpp
index f839f9e326..b5c95441f8 100644
--- a/src/test/fuzz/message.cpp
+++ b/src/test/fuzz/message.cpp
@@ -28,9 +28,7 @@ FUZZ_TARGET(message, .init = initialize_message)
FuzzedDataProvider fuzzed_data_provider(buffer.data(), buffer.size());
const std::string random_message = fuzzed_data_provider.ConsumeRandomLengthString(1024);
{
- const std::vector<uint8_t> random_bytes = ConsumeRandomLengthByteVector(fuzzed_data_provider);
- CKey private_key;
- private_key.Set(random_bytes.begin(), random_bytes.end(), fuzzed_data_provider.ConsumeBool());
+ CKey private_key = ConsumePrivateKey(fuzzed_data_provider);
std::string signature;
const bool message_signed = MessageSign(private_key, random_message, signature);
if (private_key.IsValid()) {
diff --git a/src/test/fuzz/net.cpp b/src/test/fuzz/net.cpp
index ddf919f2e6..c882bd766a 100644
--- a/src/test/fuzz/net.cpp
+++ b/src/test/fuzz/net.cpp
@@ -53,7 +53,7 @@ FUZZ_TARGET(net, .init = initialize_net)
}
},
[&] {
- const std::optional<CService> service_opt = ConsumeDeserializable<CService>(fuzzed_data_provider);
+ const std::optional<CService> service_opt = ConsumeDeserializable<CService>(fuzzed_data_provider, ConsumeDeserializationParams<CNetAddr::SerParams>(fuzzed_data_provider));
if (!service_opt) {
return;
}
diff --git a/src/test/fuzz/p2p_transport_serialization.cpp b/src/test/fuzz/p2p_transport_serialization.cpp
index 2fa5de5008..6e3ef2a707 100644
--- a/src/test/fuzz/p2p_transport_serialization.cpp
+++ b/src/test/fuzz/p2p_transport_serialization.cpp
@@ -25,6 +25,7 @@ std::vector<std::string> g_all_messages;
void initialize_p2p_transport_serialization()
{
+ ECC_Start();
SelectParams(ChainType::REGTEST);
g_all_messages = getAllNetMessageTypes();
std::sort(g_all_messages.begin(), g_all_messages.end());
@@ -92,7 +93,7 @@ FUZZ_TARGET(p2p_transport_serialization, .init = initialize_p2p_transport_serial
assert(queued);
std::optional<bool> known_more;
while (true) {
- const auto& [to_send, more, _msg_type] = send_transport.GetBytesToSend();
+ const auto& [to_send, more, _msg_type] = send_transport.GetBytesToSend(false);
if (known_more) assert(!to_send.empty() == *known_more);
if (to_send.empty()) break;
send_transport.MarkBytesSent(to_send.size());
@@ -124,11 +125,13 @@ void SimulationTest(Transport& initiator, Transport& responder, R& rng, FuzzedDa
// Vectors with bytes last returned by GetBytesToSend() on transport[i].
std::array<std::vector<uint8_t>, 2> to_send;
- // Last returned 'more' values (if still relevant) by transport[i]->GetBytesToSend().
- std::array<std::optional<bool>, 2> last_more;
+ // Last returned 'more' values (if still relevant) by transport[i]->GetBytesToSend(), for
+ // both have_next_message false and true.
+ std::array<std::optional<bool>, 2> last_more, last_more_next;
- // Whether more bytes to be sent are expected on transport[i].
- std::array<std::optional<bool>, 2> expect_more;
+ // Whether more bytes to be sent are expected on transport[i], before and after
+ // SetMessageToSend().
+ std::array<std::optional<bool>, 2> expect_more, expect_more_next;
// Function to consume a message type.
auto msg_type_fn = [&]() {
@@ -177,18 +180,27 @@ void SimulationTest(Transport& initiator, Transport& responder, R& rng, FuzzedDa
// Wrapper around transport[i]->GetBytesToSend() that performs sanity checks.
auto bytes_to_send_fn = [&](int side) -> Transport::BytesToSend {
- const auto& [bytes, more, msg_type] = transports[side]->GetBytesToSend();
+ // Invoke GetBytesToSend twice (for have_next_message = {false, true}). This function does
+ // not modify state (it's const), and only the "more" return value should differ between
+ // the calls.
+ const auto& [bytes, more_nonext, msg_type] = transports[side]->GetBytesToSend(false);
+ const auto& [bytes_next, more_next, msg_type_next] = transports[side]->GetBytesToSend(true);
// Compare with expected more.
if (expect_more[side].has_value()) assert(!bytes.empty() == *expect_more[side]);
+ // Verify consistency between the two results.
+ assert(bytes == bytes_next);
+ assert(msg_type == msg_type_next);
+ if (more_nonext) assert(more_next);
// Compare with previously reported output.
assert(to_send[side].size() <= bytes.size());
assert(to_send[side] == Span{bytes}.first(to_send[side].size()));
to_send[side].resize(bytes.size());
std::copy(bytes.begin(), bytes.end(), to_send[side].begin());
- // Remember 'more' result.
- last_more[side] = {more};
+ // Remember 'more' results.
+ last_more[side] = {more_nonext};
+ last_more_next[side] = {more_next};
// Return.
- return {bytes, more, msg_type};
+ return {bytes, more_nonext, msg_type};
};
// Function to make side send a new message.
@@ -199,7 +211,8 @@ void SimulationTest(Transport& initiator, Transport& responder, R& rng, FuzzedDa
CSerializedNetMsg msg = next_msg[side].Copy();
bool queued = transports[side]->SetMessageToSend(msg);
// Update expected more data.
- expect_more[side] = std::nullopt;
+ expect_more[side] = expect_more_next[side];
+ expect_more_next[side] = std::nullopt;
// Verify consistency of GetBytesToSend after SetMessageToSend
bytes_to_send_fn(/*side=*/side);
if (queued) {
@@ -223,6 +236,7 @@ void SimulationTest(Transport& initiator, Transport& responder, R& rng, FuzzedDa
// If all to-be-sent bytes were sent, move last_more data to expect_more data.
if (send_now == bytes.size()) {
expect_more[side] = last_more[side];
+ expect_more_next[side] = last_more_next[side];
}
// Remove the bytes from the last reported to-be-sent vector.
assert(to_send[side].size() >= send_now);
@@ -251,6 +265,7 @@ void SimulationTest(Transport& initiator, Transport& responder, R& rng, FuzzedDa
// Clear cached expected 'more' information: if certainly no more data was to be sent
// before, receiving bytes makes this uncertain.
if (expect_more[!side] == false) expect_more[!side] = std::nullopt;
+ if (expect_more_next[!side] == false) expect_more_next[!side] = std::nullopt;
// Verify consistency of GetBytesToSend after ReceivedBytes
bytes_to_send_fn(/*side=*/!side);
bool progress = to_recv.size() < old_len;
@@ -320,6 +335,40 @@ std::unique_ptr<Transport> MakeV1Transport(NodeId nodeid) noexcept
return std::make_unique<V1Transport>(nodeid, SER_NETWORK, INIT_PROTO_VERSION);
}
+template<typename RNG>
+std::unique_ptr<Transport> MakeV2Transport(NodeId nodeid, bool initiator, RNG& rng, FuzzedDataProvider& provider)
+{
+ // Retrieve key
+ auto key = ConsumePrivateKey(provider);
+ if (!key.IsValid()) return {};
+ // Construct garbage
+ size_t garb_len = provider.ConsumeIntegralInRange<size_t>(0, V2Transport::MAX_GARBAGE_LEN);
+ std::vector<uint8_t> garb;
+ if (garb_len <= 64) {
+ // When the garbage length is up to 64 bytes, read it directly from the fuzzer input.
+ garb = provider.ConsumeBytes<uint8_t>(garb_len);
+ garb.resize(garb_len);
+ } else {
+ // If it's longer, generate it from the RNG. This avoids having large amounts of
+ // (hopefully) irrelevant data needing to be stored in the fuzzer data.
+ for (auto& v : garb) v = uint8_t(rng());
+ }
+ // Retrieve entropy
+ auto ent = provider.ConsumeBytes<std::byte>(32);
+ ent.resize(32);
+ // Use as entropy SHA256(ent || garbage). This prevents a situation where the fuzzer manages to
+ // include the garbage terminator (which is a function of both ellswift keys) in the garbage.
+ // This is extremely unlikely (~2^-116) with random keys/garbage, but the fuzzer can choose
+ // both non-randomly and dependently. Since the entropy is hashed anyway inside the ellswift
+ // computation, no coverage should be lost by using a hash as entropy, and it removes the
+ // possibility of garbage that happens to contain what is effectively a hash of the keys.
+ CSHA256().Write(UCharCast(ent.data()), ent.size())
+ .Write(garb.data(), garb.size())
+ .Finalize(UCharCast(ent.data()));
+
+ return std::make_unique<V2Transport>(nodeid, initiator, SER_NETWORK, INIT_PROTO_VERSION, key, ent, garb);
+}
+
} // namespace
FUZZ_TARGET(p2p_transport_bidirectional, .init = initialize_p2p_transport_serialization)
@@ -332,3 +381,25 @@ FUZZ_TARGET(p2p_transport_bidirectional, .init = initialize_p2p_transport_serial
if (!t1 || !t2) return;
SimulationTest(*t1, *t2, rng, provider);
}
+
+FUZZ_TARGET(p2p_transport_bidirectional_v2, .init = initialize_p2p_transport_serialization)
+{
+ // Test with two V2 transports talking to each other.
+ FuzzedDataProvider provider{buffer.data(), buffer.size()};
+ XoRoShiRo128PlusPlus rng(provider.ConsumeIntegral<uint64_t>());
+ auto t1 = MakeV2Transport(NodeId{0}, true, rng, provider);
+ auto t2 = MakeV2Transport(NodeId{1}, false, rng, provider);
+ if (!t1 || !t2) return;
+ SimulationTest(*t1, *t2, rng, provider);
+}
+
+FUZZ_TARGET(p2p_transport_bidirectional_v1v2, .init = initialize_p2p_transport_serialization)
+{
+ // Test with a V1 initiator talking to a V2 responder.
+ FuzzedDataProvider provider{buffer.data(), buffer.size()};
+ XoRoShiRo128PlusPlus rng(provider.ConsumeIntegral<uint64_t>());
+ auto t1 = MakeV1Transport(NodeId{0});
+ auto t2 = MakeV2Transport(NodeId{1}, false, rng, provider);
+ if (!t1 || !t2) return;
+ SimulationTest(*t1, *t2, rng, provider);
+}
diff --git a/src/test/fuzz/rpc.cpp b/src/test/fuzz/rpc.cpp
index 24ec0e4a73..74f06b481a 100644
--- a/src/test/fuzz/rpc.cpp
+++ b/src/test/fuzz/rpc.cpp
@@ -285,9 +285,7 @@ std::string ConsumeScalarRPCArgument(FuzzedDataProvider& fuzzed_data_provider)
},
[&] {
// base58 encoded key
- const std::vector<uint8_t> random_bytes = fuzzed_data_provider.ConsumeBytes<uint8_t>(32);
- CKey key;
- key.Set(random_bytes.begin(), random_bytes.end(), fuzzed_data_provider.ConsumeBool());
+ CKey key = ConsumePrivateKey(fuzzed_data_provider);
if (!key.IsValid()) {
return;
}
@@ -295,9 +293,7 @@ std::string ConsumeScalarRPCArgument(FuzzedDataProvider& fuzzed_data_provider)
},
[&] {
// hex encoded pubkey
- const std::vector<uint8_t> random_bytes = fuzzed_data_provider.ConsumeBytes<uint8_t>(32);
- CKey key;
- key.Set(random_bytes.begin(), random_bytes.end(), fuzzed_data_provider.ConsumeBool());
+ CKey key = ConsumePrivateKey(fuzzed_data_provider);
if (!key.IsValid()) {
return;
}
diff --git a/src/test/fuzz/script_sign.cpp b/src/test/fuzz/script_sign.cpp
index cec98432e1..179715b6ea 100644
--- a/src/test/fuzz/script_sign.cpp
+++ b/src/test/fuzz/script_sign.cpp
@@ -36,7 +36,8 @@ FUZZ_TARGET(script_sign, .init = initialize_script_sign)
const std::vector<uint8_t> key = ConsumeRandomLengthByteVector(fuzzed_data_provider, 128);
{
- CDataStream random_data_stream = ConsumeDataStream(fuzzed_data_provider);
+ DataStream stream{ConsumeDataStream(fuzzed_data_provider)};
+ CDataStream random_data_stream{stream, SER_NETWORK, INIT_PROTO_VERSION}; // temporary copy, to be removed along with the version flag SERIALIZE_TRANSACTION_NO_WITNESS
std::map<CPubKey, KeyOriginInfo> hd_keypaths;
try {
DeserializeHDKeypaths(random_data_stream, key, hd_keypaths);
@@ -79,9 +80,7 @@ FUZZ_TARGET(script_sign, .init = initialize_script_sign)
}
FillableSigningProvider provider;
- CKey k;
- const std::vector<uint8_t> key_data = ConsumeRandomLengthByteVector(fuzzed_data_provider);
- k.Set(key_data.begin(), key_data.end(), fuzzed_data_provider.ConsumeBool());
+ CKey k = ConsumePrivateKey(fuzzed_data_provider);
if (k.IsValid()) {
provider.AddKey(k);
}
diff --git a/src/test/fuzz/util.cpp b/src/test/fuzz/util.cpp
index 9da84fe90e..ca2218e94c 100644
--- a/src/test/fuzz/util.cpp
+++ b/src/test/fuzz/util.cpp
@@ -14,6 +14,19 @@
#include <memory>
+std::vector<uint8_t> ConstructPubKeyBytes(FuzzedDataProvider& fuzzed_data_provider, Span<const uint8_t> byte_data, const bool compressed) noexcept
+{
+ uint8_t pk_type;
+ if (compressed) {
+ pk_type = fuzzed_data_provider.PickValueInArray({0x02, 0x03});
+ } else {
+ pk_type = fuzzed_data_provider.PickValueInArray({0x04, 0x06, 0x07});
+ }
+ std::vector<uint8_t> pk_data{byte_data.begin(), byte_data.begin() + (compressed ? CPubKey::COMPRESSED_SIZE : CPubKey::SIZE)};
+ pk_data[0] = pk_type;
+ return pk_data;
+}
+
CAmount ConsumeMoney(FuzzedDataProvider& fuzzed_data_provider, const std::optional<CAmount>& max) noexcept
{
return fuzzed_data_provider.ConsumeIntegralInRange<CAmount>(0, max.value_or(MAX_MONEY));
@@ -103,16 +116,12 @@ CScript ConsumeScript(FuzzedDataProvider& fuzzed_data_provider, const bool maybe
// navigate the highly structured multisig format.
r_script << fuzzed_data_provider.ConsumeIntegralInRange<int64_t>(0, 22);
int num_data{fuzzed_data_provider.ConsumeIntegralInRange(1, 22)};
- std::vector<uint8_t> pubkey_comp{buffer.begin(), buffer.begin() + CPubKey::COMPRESSED_SIZE};
- pubkey_comp.front() = fuzzed_data_provider.ConsumeIntegralInRange(2, 3); // Set first byte for GetLen() to pass
- std::vector<uint8_t> pubkey_uncomp{buffer.begin(), buffer.begin() + CPubKey::SIZE};
- pubkey_uncomp.front() = fuzzed_data_provider.ConsumeIntegralInRange(4, 7); // Set first byte for GetLen() to pass
while (num_data--) {
- auto& pubkey{fuzzed_data_provider.ConsumeBool() ? pubkey_uncomp : pubkey_comp};
+ auto pubkey_bytes{ConstructPubKeyBytes(fuzzed_data_provider, buffer, fuzzed_data_provider.ConsumeBool())};
if (fuzzed_data_provider.ConsumeBool()) {
- pubkey.back() = num_data; // Make each pubkey different
+ pubkey_bytes.back() = num_data; // Make each pubkey different
}
- r_script << pubkey;
+ r_script << pubkey_bytes;
}
r_script << fuzzed_data_provider.ConsumeIntegralInRange<int64_t>(0, 22);
},
@@ -193,6 +202,16 @@ CTxDestination ConsumeTxDestination(FuzzedDataProvider& fuzzed_data_provider) no
return tx_destination;
}
+CKey ConsumePrivateKey(FuzzedDataProvider& fuzzed_data_provider, std::optional<bool> compressed) noexcept
+{
+ auto key_data = fuzzed_data_provider.ConsumeBytes<uint8_t>(32);
+ key_data.resize(32);
+ CKey key;
+ bool compressed_value = compressed ? *compressed : fuzzed_data_provider.ConsumeBool();
+ key.Set(key_data.begin(), key_data.end(), compressed_value);
+ return key;
+}
+
bool ContainsSpentInput(const CTransaction& tx, const CCoinsViewCache& inputs) noexcept
{
for (const CTxIn& tx_in : tx.vin) {
diff --git a/src/test/fuzz/util.h b/src/test/fuzz/util.h
index 5d27d2a180..8263cd4c08 100644
--- a/src/test/fuzz/util.h
+++ b/src/test/fuzz/util.h
@@ -11,6 +11,7 @@
#include <compat/compat.h>
#include <consensus/amount.h>
#include <consensus/consensus.h>
+#include <key.h>
#include <merkleblock.h>
#include <primitives/transaction.h>
#include <script/script.h>
@@ -70,9 +71,9 @@ template<typename B = uint8_t>
return BytesToBits(ConsumeRandomLengthByteVector(fuzzed_data_provider, max_length));
}
-[[nodiscard]] inline CDataStream ConsumeDataStream(FuzzedDataProvider& fuzzed_data_provider, const std::optional<size_t>& max_length = std::nullopt) noexcept
+[[nodiscard]] inline DataStream ConsumeDataStream(FuzzedDataProvider& fuzzed_data_provider, const std::optional<size_t>& max_length = std::nullopt) noexcept
{
- return CDataStream{ConsumeRandomLengthByteVector(fuzzed_data_provider, max_length), SER_NETWORK, INIT_PROTO_VERSION};
+ return DataStream{ConsumeRandomLengthByteVector(fuzzed_data_provider, max_length)};
}
[[nodiscard]] inline std::vector<std::string> ConsumeRandomLengthStringVector(FuzzedDataProvider& fuzzed_data_provider, const size_t max_vector_size = 16, const size_t max_string_length = 16) noexcept
@@ -96,6 +97,23 @@ template <typename T>
return r;
}
+template <typename P>
+[[nodiscard]] P ConsumeDeserializationParams(FuzzedDataProvider& fuzzed_data_provider) noexcept;
+
+template <typename T, typename P>
+[[nodiscard]] std::optional<T> ConsumeDeserializable(FuzzedDataProvider& fuzzed_data_provider, const P& params, const std::optional<size_t>& max_length = std::nullopt) noexcept
+{
+ const std::vector<uint8_t> buffer{ConsumeRandomLengthByteVector(fuzzed_data_provider, max_length)};
+ DataStream ds{buffer};
+ T obj;
+ try {
+ ds >> WithParams(params, obj);
+ } catch (const std::ios_base::failure&) {
+ return std::nullopt;
+ }
+ return obj;
+}
+
template <typename T>
[[nodiscard]] inline std::optional<T> ConsumeDeserializable(FuzzedDataProvider& fuzzed_data_provider, const std::optional<size_t>& max_length = std::nullopt) noexcept
{
@@ -165,6 +183,8 @@ template <typename WeakEnumType, size_t size>
[[nodiscard]] CTxDestination ConsumeTxDestination(FuzzedDataProvider& fuzzed_data_provider) noexcept;
+[[nodiscard]] CKey ConsumePrivateKey(FuzzedDataProvider& fuzzed_data_provider, std::optional<bool> compressed = std::nullopt) noexcept;
+
template <typename T>
[[nodiscard]] bool MultiplicationOverflow(const T i, const T j) noexcept
{
diff --git a/src/test/fuzz/util/net.cpp b/src/test/fuzz/util/net.cpp
index 65bc336297..1545e11065 100644
--- a/src/test/fuzz/util/net.cpp
+++ b/src/test/fuzz/util/net.cpp
@@ -55,6 +55,27 @@ CAddress ConsumeAddress(FuzzedDataProvider& fuzzed_data_provider) noexcept
return {ConsumeService(fuzzed_data_provider), ConsumeWeakEnum(fuzzed_data_provider, ALL_SERVICE_FLAGS), NodeSeconds{std::chrono::seconds{fuzzed_data_provider.ConsumeIntegral<uint32_t>()}}};
}
+template <typename P>
+P ConsumeDeserializationParams(FuzzedDataProvider& fuzzed_data_provider) noexcept
+{
+ constexpr std::array ADDR_ENCODINGS{
+ CNetAddr::Encoding::V1,
+ CNetAddr::Encoding::V2,
+ };
+ constexpr std::array ADDR_FORMATS{
+ CAddress::Format::Disk,
+ CAddress::Format::Network,
+ };
+ if constexpr (std::is_same_v<P, CNetAddr::SerParams>) {
+ return P{PickValue(fuzzed_data_provider, ADDR_ENCODINGS)};
+ }
+ if constexpr (std::is_same_v<P, CAddress::SerParams>) {
+ return P{{PickValue(fuzzed_data_provider, ADDR_ENCODINGS)}, PickValue(fuzzed_data_provider, ADDR_FORMATS)};
+ }
+}
+template CNetAddr::SerParams ConsumeDeserializationParams(FuzzedDataProvider&) noexcept;
+template CAddress::SerParams ConsumeDeserializationParams(FuzzedDataProvider&) noexcept;
+
FuzzedSock::FuzzedSock(FuzzedDataProvider& fuzzed_data_provider)
: m_fuzzed_data_provider{fuzzed_data_provider}, m_selectable{fuzzed_data_provider.ConsumeBool()}
{
diff --git a/src/test/net_tests.cpp b/src/test/net_tests.cpp
index ae342a6278..900e311d22 100644
--- a/src/test/net_tests.cpp
+++ b/src/test/net_tests.cpp
@@ -15,6 +15,7 @@
#include <serialize.h>
#include <span.h>
#include <streams.h>
+#include <test/util/random.h>
#include <test/util/setup_common.h>
#include <test/util/validation.h>
#include <timedata.h>
@@ -327,19 +328,20 @@ BOOST_AUTO_TEST_CASE(cnetaddr_tostring_canonical_ipv6)
BOOST_AUTO_TEST_CASE(cnetaddr_serialize_v1)
{
CNetAddr addr;
- CDataStream s(SER_NETWORK, PROTOCOL_VERSION);
+ DataStream s{};
+ const auto ser_params{CAddress::V1_NETWORK};
- s << addr;
+ s << WithParams(ser_params, addr);
BOOST_CHECK_EQUAL(HexStr(s), "00000000000000000000000000000000");
s.clear();
addr = LookupHost("1.2.3.4", false).value();
- s << addr;
+ s << WithParams(ser_params, addr);
BOOST_CHECK_EQUAL(HexStr(s), "00000000000000000000ffff01020304");
s.clear();
addr = LookupHost("1a1b:2a2b:3a3b:4a4b:5a5b:6a6b:7a7b:8a8b", false).value();
- s << addr;
+ s << WithParams(ser_params, addr);
BOOST_CHECK_EQUAL(HexStr(s), "1a1b2a2b3a3b4a4b5a5b6a6b7a7b8a8b");
s.clear();
@@ -347,12 +349,12 @@ BOOST_AUTO_TEST_CASE(cnetaddr_serialize_v1)
BOOST_CHECK(!addr.SetSpecial("6hzph5hv6337r6p2.onion"));
BOOST_REQUIRE(addr.SetSpecial("pg6mmjiyjmcrsslvykfwnntlaru7p5svn6y2ymmju6nubxndf4pscryd.onion"));
- s << addr;
+ s << WithParams(ser_params, addr);
BOOST_CHECK_EQUAL(HexStr(s), "00000000000000000000000000000000");
s.clear();
addr.SetInternal("a");
- s << addr;
+ s << WithParams(ser_params, addr);
BOOST_CHECK_EQUAL(HexStr(s), "fd6b88c08724ca978112ca1bbdcafac2");
s.clear();
}
@@ -360,22 +362,20 @@ BOOST_AUTO_TEST_CASE(cnetaddr_serialize_v1)
BOOST_AUTO_TEST_CASE(cnetaddr_serialize_v2)
{
CNetAddr addr;
- CDataStream s(SER_NETWORK, PROTOCOL_VERSION);
- // Add ADDRV2_FORMAT to the version so that the CNetAddr
- // serialize method produces an address in v2 format.
- s.SetVersion(s.GetVersion() | ADDRV2_FORMAT);
+ DataStream s{};
+ const auto ser_params{CAddress::V2_NETWORK};
- s << addr;
+ s << WithParams(ser_params, addr);
BOOST_CHECK_EQUAL(HexStr(s), "021000000000000000000000000000000000");
s.clear();
addr = LookupHost("1.2.3.4", false).value();
- s << addr;
+ s << WithParams(ser_params, addr);
BOOST_CHECK_EQUAL(HexStr(s), "010401020304");
s.clear();
addr = LookupHost("1a1b:2a2b:3a3b:4a4b:5a5b:6a6b:7a7b:8a8b", false).value();
- s << addr;
+ s << WithParams(ser_params, addr);
BOOST_CHECK_EQUAL(HexStr(s), "02101a1b2a2b3a3b4a4b5a5b6a6b7a7b8a8b");
s.clear();
@@ -383,12 +383,12 @@ BOOST_AUTO_TEST_CASE(cnetaddr_serialize_v2)
BOOST_CHECK(!addr.SetSpecial("6hzph5hv6337r6p2.onion"));
BOOST_REQUIRE(addr.SetSpecial("kpgvmscirrdqpekbqjsvw5teanhatztpp2gl6eee4zkowvwfxwenqaid.onion"));
- s << addr;
+ s << WithParams(ser_params, addr);
BOOST_CHECK_EQUAL(HexStr(s), "042053cd5648488c4707914182655b7664034e09e66f7e8cbf1084e654eb56c5bd88");
s.clear();
BOOST_REQUIRE(addr.SetInternal("a"));
- s << addr;
+ s << WithParams(ser_params, addr);
BOOST_CHECK_EQUAL(HexStr(s), "0210fd6b88c08724ca978112ca1bbdcafac2");
s.clear();
}
@@ -396,16 +396,14 @@ BOOST_AUTO_TEST_CASE(cnetaddr_serialize_v2)
BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2)
{
CNetAddr addr;
- CDataStream s(SER_NETWORK, PROTOCOL_VERSION);
- // Add ADDRV2_FORMAT to the version so that the CNetAddr
- // unserialize method expects an address in v2 format.
- s.SetVersion(s.GetVersion() | ADDRV2_FORMAT);
+ DataStream s{};
+ const auto ser_params{CAddress::V2_NETWORK};
// Valid IPv4.
s << Span{ParseHex("01" // network type (IPv4)
"04" // address length
"01020304")}; // address
- s >> addr;
+ s >> WithParams(ser_params, addr);
BOOST_CHECK(addr.IsValid());
BOOST_CHECK(addr.IsIPv4());
BOOST_CHECK(addr.IsAddrV1Compatible());
@@ -416,7 +414,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2)
s << Span{ParseHex("01" // network type (IPv4)
"04" // address length
"0102")}; // address
- BOOST_CHECK_EXCEPTION(s >> addr, std::ios_base::failure, HasReason("end of data"));
+ BOOST_CHECK_EXCEPTION(s >> WithParams(ser_params, addr), std::ios_base::failure, HasReason("end of data"));
BOOST_REQUIRE(!s.empty()); // The stream is not consumed on invalid input.
s.clear();
@@ -424,7 +422,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2)
s << Span{ParseHex("01" // network type (IPv4)
"05" // address length
"01020304")}; // address
- BOOST_CHECK_EXCEPTION(s >> addr, std::ios_base::failure,
+ BOOST_CHECK_EXCEPTION(s >> WithParams(ser_params, addr), std::ios_base::failure,
HasReason("BIP155 IPv4 address with length 5 (should be 4)"));
BOOST_REQUIRE(!s.empty()); // The stream is not consumed on invalid input.
s.clear();
@@ -433,7 +431,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2)
s << Span{ParseHex("01" // network type (IPv4)
"fd0102" // address length (513 as CompactSize)
"01020304")}; // address
- BOOST_CHECK_EXCEPTION(s >> addr, std::ios_base::failure,
+ BOOST_CHECK_EXCEPTION(s >> WithParams(ser_params, addr), std::ios_base::failure,
HasReason("Address too long: 513 > 512"));
BOOST_REQUIRE(!s.empty()); // The stream is not consumed on invalid input.
s.clear();
@@ -442,7 +440,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2)
s << Span{ParseHex("02" // network type (IPv6)
"10" // address length
"0102030405060708090a0b0c0d0e0f10")}; // address
- s >> addr;
+ s >> WithParams(ser_params, addr);
BOOST_CHECK(addr.IsValid());
BOOST_CHECK(addr.IsIPv6());
BOOST_CHECK(addr.IsAddrV1Compatible());
@@ -455,7 +453,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2)
"10" // address length
"fd6b88c08724ca978112ca1bbdcafac2")}; // address: 0xfd + sha256("bitcoin")[0:5] +
// sha256(name)[0:10]
- s >> addr;
+ s >> WithParams(ser_params, addr);
BOOST_CHECK(addr.IsInternal());
BOOST_CHECK(addr.IsAddrV1Compatible());
BOOST_CHECK_EQUAL(addr.ToStringAddr(), "zklycewkdo64v6wc.internal");
@@ -465,7 +463,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2)
s << Span{ParseHex("02" // network type (IPv6)
"04" // address length
"00")}; // address
- BOOST_CHECK_EXCEPTION(s >> addr, std::ios_base::failure,
+ BOOST_CHECK_EXCEPTION(s >> WithParams(ser_params, addr), std::ios_base::failure,
HasReason("BIP155 IPv6 address with length 4 (should be 16)"));
BOOST_REQUIRE(!s.empty()); // The stream is not consumed on invalid input.
s.clear();
@@ -474,7 +472,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2)
s << Span{ParseHex("02" // network type (IPv6)
"10" // address length
"00000000000000000000ffff01020304")}; // address
- s >> addr;
+ s >> WithParams(ser_params, addr);
BOOST_CHECK(!addr.IsValid());
BOOST_REQUIRE(s.empty());
@@ -482,7 +480,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2)
s << Span{ParseHex("02" // network type (IPv6)
"10" // address length
"fd87d87eeb430102030405060708090a")}; // address
- s >> addr;
+ s >> WithParams(ser_params, addr);
BOOST_CHECK(!addr.IsValid());
BOOST_REQUIRE(s.empty());
@@ -490,7 +488,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2)
s << Span{ParseHex("03" // network type (TORv2)
"0a" // address length
"f1f2f3f4f5f6f7f8f9fa")}; // address
- s >> addr;
+ s >> WithParams(ser_params, addr);
BOOST_CHECK(!addr.IsValid());
BOOST_REQUIRE(s.empty());
@@ -500,7 +498,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2)
"79bcc625184b05194975c28b66b66b04" // address
"69f7f6556fb1ac3189a79b40dda32f1f"
)};
- s >> addr;
+ s >> WithParams(ser_params, addr);
BOOST_CHECK(addr.IsValid());
BOOST_CHECK(addr.IsTor());
BOOST_CHECK(!addr.IsAddrV1Compatible());
@@ -513,7 +511,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2)
"00" // address length
"00" // address
)};
- BOOST_CHECK_EXCEPTION(s >> addr, std::ios_base::failure,
+ BOOST_CHECK_EXCEPTION(s >> WithParams(ser_params, addr), std::ios_base::failure,
HasReason("BIP155 TORv3 address with length 0 (should be 32)"));
BOOST_REQUIRE(!s.empty()); // The stream is not consumed on invalid input.
s.clear();
@@ -523,7 +521,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2)
"20" // address length
"a2894dabaec08c0051a481a6dac88b64" // address
"f98232ae42d4b6fd2fa81952dfe36a87")};
- s >> addr;
+ s >> WithParams(ser_params, addr);
BOOST_CHECK(addr.IsValid());
BOOST_CHECK(addr.IsI2P());
BOOST_CHECK(!addr.IsAddrV1Compatible());
@@ -536,7 +534,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2)
"03" // address length
"00" // address
)};
- BOOST_CHECK_EXCEPTION(s >> addr, std::ios_base::failure,
+ BOOST_CHECK_EXCEPTION(s >> WithParams(ser_params, addr), std::ios_base::failure,
HasReason("BIP155 I2P address with length 3 (should be 32)"));
BOOST_REQUIRE(!s.empty()); // The stream is not consumed on invalid input.
s.clear();
@@ -546,7 +544,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2)
"10" // address length
"fc000001000200030004000500060007" // address
)};
- s >> addr;
+ s >> WithParams(ser_params, addr);
BOOST_CHECK(addr.IsValid());
BOOST_CHECK(addr.IsCJDNS());
BOOST_CHECK(!addr.IsAddrV1Compatible());
@@ -558,7 +556,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2)
"10" // address length
"aa000001000200030004000500060007" // address
)};
- s >> addr;
+ s >> WithParams(ser_params, addr);
BOOST_CHECK(addr.IsCJDNS());
BOOST_CHECK(!addr.IsValid());
BOOST_REQUIRE(s.empty());
@@ -568,7 +566,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2)
"01" // address length
"00" // address
)};
- BOOST_CHECK_EXCEPTION(s >> addr, std::ios_base::failure,
+ BOOST_CHECK_EXCEPTION(s >> WithParams(ser_params, addr), std::ios_base::failure,
HasReason("BIP155 CJDNS address with length 1 (should be 16)"));
BOOST_REQUIRE(!s.empty()); // The stream is not consumed on invalid input.
s.clear();
@@ -578,7 +576,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2)
"fe00000002" // address length (CompactSize's MAX_SIZE)
"01020304050607" // address
)};
- BOOST_CHECK_EXCEPTION(s >> addr, std::ios_base::failure,
+ BOOST_CHECK_EXCEPTION(s >> WithParams(ser_params, addr), std::ios_base::failure,
HasReason("Address too long: 33554432 > 512"));
BOOST_REQUIRE(!s.empty()); // The stream is not consumed on invalid input.
s.clear();
@@ -588,7 +586,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2)
"04" // address length
"01020304" // address
)};
- s >> addr;
+ s >> WithParams(ser_params, addr);
BOOST_CHECK(!addr.IsValid());
BOOST_REQUIRE(s.empty());
@@ -597,7 +595,7 @@ BOOST_AUTO_TEST_CASE(cnetaddr_unserialize_v2)
"00" // address length
"" // address
)};
- s >> addr;
+ s >> WithParams(ser_params, addr);
BOOST_CHECK(!addr.IsValid());
BOOST_REQUIRE(s.empty());
}
@@ -852,7 +850,7 @@ BOOST_AUTO_TEST_CASE(initial_advertise_from_version_message)
std::chrono::microseconds time_received_dummy{0};
const auto msg_version =
- msg_maker.Make(NetMsgType::VERSION, PROTOCOL_VERSION, services, time, services, peer_us);
+ msg_maker.Make(NetMsgType::VERSION, PROTOCOL_VERSION, services, time, services, WithParams(CAddress::V1_NETWORK, peer_us));
CDataStream msg_version_stream{msg_version.data, SER_NETWORK, PROTOCOL_VERSION};
m_node.peerman->ProcessMessage(
@@ -875,10 +873,10 @@ BOOST_AUTO_TEST_CASE(initial_advertise_from_version_message)
Span<const unsigned char> data,
bool is_incoming) -> void {
if (!is_incoming && msg_type == "addr") {
- CDataStream s(data, SER_NETWORK, PROTOCOL_VERSION);
+ DataStream s{data};
std::vector<CAddress> addresses;
- s >> addresses;
+ s >> WithParams(CAddress::V1_NETWORK, addresses);
for (const auto& addr : addresses) {
if (addr == expected) {
@@ -1008,4 +1006,530 @@ BOOST_AUTO_TEST_CASE(advertise_local_address)
RemoveLocal(addr_cjdns);
}
+namespace {
+
+/** A class for scenario-based tests of V2Transport
+ *
+ * Each V2TransportTester encapsulates a V2Transport (the one being tested), and can be told to
+ * interact with it. To do so, it also encapsulates a BIP324Cipher to act as the other side. A
+ * second V2Transport is not used, as doing so would not permit scenarios that involve sending
+ * invalid data, or ones scenarios using BIP324 features that are not implemented on the sending
+ * side (like decoy packets).
+ */
+class V2TransportTester
+{
+ V2Transport m_transport; //!< V2Transport being tested
+ BIP324Cipher m_cipher; //!< Cipher to help with the other side
+ bool m_test_initiator; //!< Whether m_transport is the initiator (true) or responder (false)
+
+ std::vector<uint8_t> m_sent_garbage; //!< The garbage we've sent to m_transport.
+ std::vector<uint8_t> m_to_send; //!< Bytes we have queued up to send to m_transport.
+ std::vector<uint8_t> m_received; //!< Bytes we have received from m_transport.
+ std::deque<CSerializedNetMsg> m_msg_to_send; //!< Messages to be sent *by* m_transport to us.
+
+public:
+ /** Construct a tester object. test_initiator: whether the tested transport is initiator. */
+ V2TransportTester(bool test_initiator) :
+ m_transport(0, test_initiator, SER_NETWORK, INIT_PROTO_VERSION),
+ m_test_initiator(test_initiator) {}
+
+ /** Data type returned by Interact:
+ *
+ * - std::nullopt: transport error occurred
+ * - otherwise: a vector of
+ * - std::nullopt: invalid message received
+ * - otherwise: a CNetMessage retrieved
+ */
+ using InteractResult = std::optional<std::vector<std::optional<CNetMessage>>>;
+
+ /** Send/receive scheduled/available bytes and messages.
+ *
+ * This is the only function that interacts with the transport being tested; everything else is
+ * scheduling things done by Interact(), or processing things learned by it.
+ */
+ InteractResult Interact()
+ {
+ std::vector<std::optional<CNetMessage>> ret;
+ while (true) {
+ bool progress{false};
+ // Send bytes from m_to_send to the transport.
+ if (!m_to_send.empty()) {
+ Span<const uint8_t> to_send = Span{m_to_send}.first(1 + InsecureRandRange(m_to_send.size()));
+ size_t old_len = to_send.size();
+ if (!m_transport.ReceivedBytes(to_send)) {
+ return std::nullopt; // transport error occurred
+ }
+ if (old_len != to_send.size()) {
+ progress = true;
+ m_to_send.erase(m_to_send.begin(), m_to_send.begin() + (old_len - to_send.size()));
+ }
+ }
+ // Retrieve messages received by the transport.
+ if (m_transport.ReceivedMessageComplete() && (!progress || InsecureRandBool())) {
+ bool reject{false};
+ auto msg = m_transport.GetReceivedMessage({}, reject);
+ if (reject) {
+ ret.push_back(std::nullopt);
+ } else {
+ ret.push_back(std::move(msg));
+ }
+ progress = true;
+ }
+ // Enqueue a message to be sent by the transport to us.
+ if (!m_msg_to_send.empty() && (!progress || InsecureRandBool())) {
+ if (m_transport.SetMessageToSend(m_msg_to_send.front())) {
+ m_msg_to_send.pop_front();
+ progress = true;
+ }
+ }
+ // Receive bytes from the transport.
+ const auto& [recv_bytes, _more, _msg_type] = m_transport.GetBytesToSend(!m_msg_to_send.empty());
+ if (!recv_bytes.empty() && (!progress || InsecureRandBool())) {
+ size_t to_receive = 1 + InsecureRandRange(recv_bytes.size());
+ m_received.insert(m_received.end(), recv_bytes.begin(), recv_bytes.begin() + to_receive);
+ progress = true;
+ m_transport.MarkBytesSent(to_receive);
+ }
+ if (!progress) break;
+ }
+ return ret;
+ }
+
+ /** Expose the cipher. */
+ BIP324Cipher& GetCipher() { return m_cipher; }
+
+ /** Schedule bytes to be sent to the transport. */
+ void Send(Span<const uint8_t> data)
+ {
+ m_to_send.insert(m_to_send.end(), data.begin(), data.end());
+ }
+
+ /** Send V1 version message header to the transport. */
+ void SendV1Version(const CMessageHeader::MessageStartChars& magic)
+ {
+ CMessageHeader hdr(magic, "version", 126 + InsecureRandRange(11));
+ CDataStream ser(SER_NETWORK, CLIENT_VERSION);
+ ser << hdr;
+ m_to_send.insert(m_to_send.end(), UCharCast(ser.data()), UCharCast(ser.data() + ser.size()));
+ }
+
+ /** Schedule bytes to be sent to the transport. */
+ void Send(Span<const std::byte> data) { Send(MakeUCharSpan(data)); }
+
+ /** Schedule our ellswift key to be sent to the transport. */
+ void SendKey() { Send(m_cipher.GetOurPubKey()); }
+
+ /** Schedule specified garbage to be sent to the transport. */
+ void SendGarbage(Span<const uint8_t> garbage)
+ {
+ // Remember the specified garbage (so we can use it for constructing the garbage
+ // authentication packet).
+ m_sent_garbage.assign(garbage.begin(), garbage.end());
+ // Schedule it for sending.
+ Send(m_sent_garbage);
+ }
+
+ /** Schedule garbage (of specified length) to be sent to the transport. */
+ void SendGarbage(size_t garbage_len)
+ {
+ // Generate random garbage and send it.
+ SendGarbage(g_insecure_rand_ctx.randbytes<uint8_t>(garbage_len));
+ }
+
+ /** Schedule garbage (with valid random length) to be sent to the transport. */
+ void SendGarbage()
+ {
+ SendGarbage(InsecureRandRange(V2Transport::MAX_GARBAGE_LEN + 1));
+ }
+
+ /** Schedule a message to be sent to us by the transport. */
+ void AddMessage(std::string m_type, std::vector<uint8_t> payload)
+ {
+ CSerializedNetMsg msg;
+ msg.m_type = std::move(m_type);
+ msg.data = std::move(payload);
+ m_msg_to_send.push_back(std::move(msg));
+ }
+
+ /** Expect ellswift key to have been received from transport and process it.
+ *
+ * Many other V2TransportTester functions cannot be called until after ReceiveKey() has been
+ * called, as no encryption keys are set up before that point.
+ */
+ void ReceiveKey()
+ {
+ // When processing a key, enough bytes need to have been received already.
+ BOOST_REQUIRE(m_received.size() >= EllSwiftPubKey::size());
+ // Initialize the cipher using it (acting as the opposite side of the tested transport).
+ m_cipher.Initialize(MakeByteSpan(m_received).first(EllSwiftPubKey::size()), !m_test_initiator);
+ // Strip the processed bytes off the front of the receive buffer.
+ m_received.erase(m_received.begin(), m_received.begin() + EllSwiftPubKey::size());
+ }
+
+ /** Schedule an encrypted packet with specified content/aad/ignore to be sent to transport
+ * (only after ReceiveKey). */
+ void SendPacket(Span<const uint8_t> content, Span<const uint8_t> aad = {}, bool ignore = false)
+ {
+ // Use cipher to construct ciphertext.
+ std::vector<std::byte> ciphertext;
+ ciphertext.resize(content.size() + BIP324Cipher::EXPANSION);
+ m_cipher.Encrypt(
+ /*contents=*/MakeByteSpan(content),
+ /*aad=*/MakeByteSpan(aad),
+ /*ignore=*/ignore,
+ /*output=*/ciphertext);
+ // Schedule it for sending.
+ Send(ciphertext);
+ }
+
+ /** Schedule garbage terminator and authentication packet to be sent to the transport (only
+ * after ReceiveKey). */
+ void SendGarbageTermAuth(size_t garb_auth_data_len = 0, bool garb_auth_ignore = false)
+ {
+ // Generate random data to include in the garbage authentication packet (ignored by peer).
+ auto garb_auth_data = g_insecure_rand_ctx.randbytes<uint8_t>(garb_auth_data_len);
+ // Schedule the garbage terminator to be sent.
+ Send(m_cipher.GetSendGarbageTerminator());
+ // Schedule the garbage authentication packet to be sent.
+ SendPacket(/*content=*/garb_auth_data, /*aad=*/m_sent_garbage, /*ignore=*/garb_auth_ignore);
+ }
+
+ /** Schedule version packet to be sent to the transport (only after ReceiveKey). */
+ void SendVersion(Span<const uint8_t> version_data = {}, bool vers_ignore = false)
+ {
+ SendPacket(/*content=*/version_data, /*aad=*/{}, /*ignore=*/vers_ignore);
+ }
+
+ /** Expect a packet to have been received from transport, process it, and return its contents
+ * (only after ReceiveKey). By default, decoys are skipped. */
+ std::vector<uint8_t> ReceivePacket(Span<const std::byte> aad = {}, bool skip_decoy = true)
+ {
+ std::vector<uint8_t> contents;
+ // Loop as long as there are ignored packets that are to be skipped.
+ while (true) {
+ // When processing a packet, at least enough bytes for its length descriptor must be received.
+ BOOST_REQUIRE(m_received.size() >= BIP324Cipher::LENGTH_LEN);
+ // Decrypt the content length.
+ size_t size = m_cipher.DecryptLength(MakeByteSpan(Span{m_received}.first(BIP324Cipher::LENGTH_LEN)));
+ // Check that the full packet is in the receive buffer.
+ BOOST_REQUIRE(m_received.size() >= size + BIP324Cipher::EXPANSION);
+ // Decrypt the packet contents.
+ contents.resize(size);
+ bool ignore{false};
+ bool ret = m_cipher.Decrypt(
+ /*input=*/MakeByteSpan(
+ Span{m_received}.first(size + BIP324Cipher::EXPANSION).subspan(BIP324Cipher::LENGTH_LEN)),
+ /*aad=*/aad,
+ /*ignore=*/ignore,
+ /*contents=*/MakeWritableByteSpan(contents));
+ BOOST_CHECK(ret);
+ // Strip the processed packet's bytes off the front of the receive buffer.
+ m_received.erase(m_received.begin(), m_received.begin() + size + BIP324Cipher::EXPANSION);
+ // Stop if the ignore bit is not set on this packet, or if we choose to not honor it.
+ if (!ignore || !skip_decoy) break;
+ }
+ return contents;
+ }
+
+ /** Expect garbage, garbage terminator, and garbage auth packet to have been received, and
+ * process them (only after ReceiveKey). */
+ void ReceiveGarbage()
+ {
+ // Figure out the garbage length.
+ size_t garblen;
+ for (garblen = 0; garblen <= V2Transport::MAX_GARBAGE_LEN; ++garblen) {
+ BOOST_REQUIRE(m_received.size() >= garblen + BIP324Cipher::GARBAGE_TERMINATOR_LEN);
+ auto term_span = MakeByteSpan(Span{m_received}.subspan(garblen, BIP324Cipher::GARBAGE_TERMINATOR_LEN));
+ if (term_span == m_cipher.GetReceiveGarbageTerminator()) break;
+ }
+ // Copy the garbage to a buffer.
+ std::vector<uint8_t> garbage(m_received.begin(), m_received.begin() + garblen);
+ // Strip garbage + garbage terminator off the front of the receive buffer.
+ m_received.erase(m_received.begin(), m_received.begin() + garblen + BIP324Cipher::GARBAGE_TERMINATOR_LEN);
+ // Process the expected garbage authentication packet. Such a packet still functions as one
+ // even when its ignore bit is set to true, so we do not skip decoy packets here.
+ ReceivePacket(/*aad=*/MakeByteSpan(garbage), /*skip_decoy=*/false);
+ }
+
+ /** Expect version packet to have been received, and process it (only after ReceiveKey). */
+ void ReceiveVersion()
+ {
+ auto contents = ReceivePacket();
+ // Version packets from real BIP324 peers are expected to be empty, despite the fact that
+ // this class supports *sending* non-empty version packets (to test that BIP324 peers
+ // correctly ignore version packet contents).
+ BOOST_CHECK(contents.empty());
+ }
+
+ /** Expect application packet to have been received, with specified short id and payload.
+ * (only after ReceiveKey). */
+ void ReceiveMessage(uint8_t short_id, Span<const uint8_t> payload)
+ {
+ auto ret = ReceivePacket();
+ BOOST_CHECK(ret.size() == payload.size() + 1);
+ BOOST_CHECK(ret[0] == short_id);
+ BOOST_CHECK(Span{ret}.subspan(1) == payload);
+ }
+
+ /** Expect application packet to have been received, with specified 12-char message type and
+ * payload (only after ReceiveKey). */
+ void ReceiveMessage(const std::string& m_type, Span<const uint8_t> payload)
+ {
+ auto ret = ReceivePacket();
+ BOOST_REQUIRE(ret.size() == payload.size() + 1 + CMessageHeader::COMMAND_SIZE);
+ BOOST_CHECK(ret[0] == 0);
+ for (unsigned i = 0; i < 12; ++i) {
+ if (i < m_type.size()) {
+ BOOST_CHECK(ret[1 + i] == m_type[i]);
+ } else {
+ BOOST_CHECK(ret[1 + i] == 0);
+ }
+ }
+ BOOST_CHECK(Span{ret}.subspan(1 + CMessageHeader::COMMAND_SIZE) == payload);
+ }
+
+ /** Schedule an encrypted packet with specified message type and payload to be sent to
+ * transport (only after ReceiveKey). */
+ void SendMessage(std::string mtype, Span<const uint8_t> payload)
+ {
+ // Construct contents consisting of 0x00 + 12-byte message type + payload.
+ std::vector<uint8_t> contents(1 + CMessageHeader::COMMAND_SIZE + payload.size());
+ std::copy(mtype.begin(), mtype.end(), reinterpret_cast<char*>(contents.data() + 1));
+ std::copy(payload.begin(), payload.end(), contents.begin() + 1 + CMessageHeader::COMMAND_SIZE);
+ // Send a packet with that as contents.
+ SendPacket(contents);
+ }
+
+ /** Schedule an encrypted packet with specified short message id and payload to be sent to
+ * transport (only after ReceiveKey). */
+ void SendMessage(uint8_t short_id, Span<const uint8_t> payload)
+ {
+ // Construct contents consisting of short_id + payload.
+ std::vector<uint8_t> contents(1 + payload.size());
+ contents[0] = short_id;
+ std::copy(payload.begin(), payload.end(), contents.begin() + 1);
+ // Send a packet with that as contents.
+ SendPacket(contents);
+ }
+
+ /** Introduce a bit error in the data scheduled to be sent. */
+ void Damage()
+ {
+ m_to_send[InsecureRandRange(m_to_send.size())] ^= (uint8_t{1} << InsecureRandRange(8));
+ }
+};
+
+} // namespace
+
+BOOST_AUTO_TEST_CASE(v2transport_test)
+{
+ // A mostly normal scenario, testing a transport in initiator mode.
+ for (int i = 0; i < 10; ++i) {
+ V2TransportTester tester(true);
+ auto ret = tester.Interact();
+ BOOST_REQUIRE(ret && ret->empty());
+ tester.SendKey();
+ tester.SendGarbage();
+ tester.ReceiveKey();
+ tester.SendGarbageTermAuth();
+ tester.SendVersion();
+ ret = tester.Interact();
+ BOOST_REQUIRE(ret && ret->empty());
+ tester.ReceiveGarbage();
+ tester.ReceiveVersion();
+ auto msg_data_1 = g_insecure_rand_ctx.randbytes<uint8_t>(InsecureRandRange(100000));
+ auto msg_data_2 = g_insecure_rand_ctx.randbytes<uint8_t>(InsecureRandRange(1000));
+ tester.SendMessage(uint8_t(4), msg_data_1); // cmpctblock short id
+ tester.SendMessage(0, {}); // Invalidly encoded message
+ tester.SendMessage("tx", msg_data_2); // 12-character encoded message type
+ ret = tester.Interact();
+ BOOST_REQUIRE(ret && ret->size() == 3);
+ BOOST_CHECK((*ret)[0] && (*ret)[0]->m_type == "cmpctblock" && Span{(*ret)[0]->m_recv} == MakeByteSpan(msg_data_1));
+ BOOST_CHECK(!(*ret)[1]);
+ BOOST_CHECK((*ret)[2] && (*ret)[2]->m_type == "tx" && Span{(*ret)[2]->m_recv} == MakeByteSpan(msg_data_2));
+
+ // Then send a message with a bit error, expecting failure.
+ tester.SendMessage("bad", msg_data_1);
+ tester.Damage();
+ ret = tester.Interact();
+ BOOST_CHECK(!ret);
+ }
+
+ // Normal scenario, with a transport in responder node.
+ for (int i = 0; i < 10; ++i) {
+ V2TransportTester tester(false);
+ tester.SendKey();
+ tester.SendGarbage();
+ auto ret = tester.Interact();
+ BOOST_REQUIRE(ret && ret->empty());
+ tester.ReceiveKey();
+ tester.SendGarbageTermAuth();
+ tester.SendVersion();
+ ret = tester.Interact();
+ BOOST_REQUIRE(ret && ret->empty());
+ tester.ReceiveGarbage();
+ tester.ReceiveVersion();
+ auto msg_data_1 = g_insecure_rand_ctx.randbytes<uint8_t>(InsecureRandRange(100000));
+ auto msg_data_2 = g_insecure_rand_ctx.randbytes<uint8_t>(InsecureRandRange(1000));
+ tester.SendMessage(uint8_t(14), msg_data_1); // inv short id
+ tester.SendMessage(uint8_t(19), msg_data_2); // pong short id
+ ret = tester.Interact();
+ BOOST_REQUIRE(ret && ret->size() == 2);
+ BOOST_CHECK((*ret)[0] && (*ret)[0]->m_type == "inv" && Span{(*ret)[0]->m_recv} == MakeByteSpan(msg_data_1));
+ BOOST_CHECK((*ret)[1] && (*ret)[1]->m_type == "pong" && Span{(*ret)[1]->m_recv} == MakeByteSpan(msg_data_2));
+
+ // Then send a too-large message.
+ auto msg_data_3 = g_insecure_rand_ctx.randbytes<uint8_t>(4005000);
+ tester.SendMessage(uint8_t(11), msg_data_3); // getdata short id
+ ret = tester.Interact();
+ BOOST_CHECK(!ret);
+ }
+
+ // Various valid but unusual scenarios.
+ for (int i = 0; i < 50; ++i) {
+ /** Whether an initiator or responder is being tested. */
+ bool initiator = InsecureRandBool();
+ /** Use either 0 bytes or the maximum possible (4095 bytes) garbage length. */
+ size_t garb_len = InsecureRandBool() ? 0 : V2Transport::MAX_GARBAGE_LEN;
+ /** Sometimes, use non-empty contents in the garbage authentication packet (which is to be ignored). */
+ size_t garb_auth_data_len = InsecureRandBool() ? 0 : InsecureRandRange(100000);
+ /** Whether to set the ignore bit on the garbage authentication packet (it still functions as garbage authentication). */
+ bool garb_ignore = InsecureRandBool();
+ /** How many decoy packets to send before the version packet. */
+ unsigned num_ignore_version = InsecureRandRange(10);
+ /** What data to send in the version packet (ignored by BIP324 peers, but reserved for future extensions). */
+ auto ver_data = g_insecure_rand_ctx.randbytes<uint8_t>(InsecureRandBool() ? 0 : InsecureRandRange(1000));
+ /** Whether to immediately send key and garbage out (required for responders, optional otherwise). */
+ bool send_immediately = !initiator || InsecureRandBool();
+ /** How many decoy packets to send before the first and second real message. */
+ unsigned num_decoys_1 = InsecureRandRange(1000), num_decoys_2 = InsecureRandRange(1000);
+ V2TransportTester tester(initiator);
+ if (send_immediately) {
+ tester.SendKey();
+ tester.SendGarbage(garb_len);
+ }
+ auto ret = tester.Interact();
+ BOOST_REQUIRE(ret && ret->empty());
+ if (!send_immediately) {
+ tester.SendKey();
+ tester.SendGarbage(garb_len);
+ }
+ tester.ReceiveKey();
+ tester.SendGarbageTermAuth(garb_auth_data_len, garb_ignore);
+ for (unsigned v = 0; v < num_ignore_version; ++v) {
+ size_t ver_ign_data_len = InsecureRandBool() ? 0 : InsecureRandRange(1000);
+ auto ver_ign_data = g_insecure_rand_ctx.randbytes<uint8_t>(ver_ign_data_len);
+ tester.SendVersion(ver_ign_data, true);
+ }
+ tester.SendVersion(ver_data, false);
+ ret = tester.Interact();
+ BOOST_REQUIRE(ret && ret->empty());
+ tester.ReceiveGarbage();
+ tester.ReceiveVersion();
+ for (unsigned d = 0; d < num_decoys_1; ++d) {
+ auto decoy_data = g_insecure_rand_ctx.randbytes<uint8_t>(InsecureRandRange(1000));
+ tester.SendPacket(/*content=*/decoy_data, /*aad=*/{}, /*ignore=*/true);
+ }
+ auto msg_data_1 = g_insecure_rand_ctx.randbytes<uint8_t>(InsecureRandRange(4000000));
+ tester.SendMessage(uint8_t(28), msg_data_1);
+ for (unsigned d = 0; d < num_decoys_2; ++d) {
+ auto decoy_data = g_insecure_rand_ctx.randbytes<uint8_t>(InsecureRandRange(1000));
+ tester.SendPacket(/*content=*/decoy_data, /*aad=*/{}, /*ignore=*/true);
+ }
+ auto msg_data_2 = g_insecure_rand_ctx.randbytes<uint8_t>(InsecureRandRange(1000));
+ tester.SendMessage(uint8_t(13), msg_data_2); // headers short id
+ // Send invalidly-encoded message
+ tester.SendMessage(std::string("blocktxn\x00\x00\x00a", CMessageHeader::COMMAND_SIZE), {});
+ tester.SendMessage("foobar", {}); // test receiving unknown message type
+ tester.AddMessage("barfoo", {}); // test sending unknown message type
+ ret = tester.Interact();
+ BOOST_REQUIRE(ret && ret->size() == 4);
+ BOOST_CHECK((*ret)[0] && (*ret)[0]->m_type == "addrv2" && Span{(*ret)[0]->m_recv} == MakeByteSpan(msg_data_1));
+ BOOST_CHECK((*ret)[1] && (*ret)[1]->m_type == "headers" && Span{(*ret)[1]->m_recv} == MakeByteSpan(msg_data_2));
+ BOOST_CHECK(!(*ret)[2]);
+ BOOST_CHECK((*ret)[3] && (*ret)[3]->m_type == "foobar" && (*ret)[3]->m_recv.empty());
+ tester.ReceiveMessage("barfoo", {});
+ }
+
+ // Too long garbage (initiator).
+ {
+ V2TransportTester tester(true);
+ auto ret = tester.Interact();
+ BOOST_REQUIRE(ret && ret->empty());
+ tester.SendKey();
+ tester.SendGarbage(V2Transport::MAX_GARBAGE_LEN + 1);
+ tester.ReceiveKey();
+ tester.SendGarbageTermAuth();
+ ret = tester.Interact();
+ BOOST_CHECK(!ret);
+ }
+
+ // Too long garbage (responder).
+ {
+ V2TransportTester tester(false);
+ tester.SendKey();
+ tester.SendGarbage(V2Transport::MAX_GARBAGE_LEN + 1);
+ auto ret = tester.Interact();
+ BOOST_REQUIRE(ret && ret->empty());
+ tester.ReceiveKey();
+ tester.SendGarbageTermAuth();
+ ret = tester.Interact();
+ BOOST_CHECK(!ret);
+ }
+
+ // Send garbage that includes the first 15 garbage terminator bytes somewhere.
+ {
+ V2TransportTester tester(true);
+ auto ret = tester.Interact();
+ BOOST_REQUIRE(ret && ret->empty());
+ tester.SendKey();
+ tester.ReceiveKey();
+ /** The number of random garbage bytes before the included first 15 bytes of terminator. */
+ size_t len_before = InsecureRandRange(V2Transport::MAX_GARBAGE_LEN - 16 + 1);
+ /** The number of random garbage bytes after it. */
+ size_t len_after = InsecureRandRange(V2Transport::MAX_GARBAGE_LEN - 16 - len_before + 1);
+ // Construct len_before + 16 + len_after random bytes.
+ auto garbage = g_insecure_rand_ctx.randbytes<uint8_t>(len_before + 16 + len_after);
+ // Replace the designed 16 bytes in the middle with the to-be-sent garbage terminator.
+ auto garb_term = MakeUCharSpan(tester.GetCipher().GetSendGarbageTerminator());
+ std::copy(garb_term.begin(), garb_term.begin() + 16, garbage.begin() + len_before);
+ // Introduce a bit error in the last byte of that copied garbage terminator, making only
+ // the first 15 of them match.
+ garbage[len_before + 15] ^= (uint8_t(1) << InsecureRandRange(8));
+ tester.SendGarbage(garbage);
+ tester.SendGarbageTermAuth();
+ tester.SendVersion();
+ ret = tester.Interact();
+ BOOST_REQUIRE(ret && ret->empty());
+ tester.ReceiveGarbage();
+ tester.ReceiveVersion();
+ auto msg_data_1 = g_insecure_rand_ctx.randbytes<uint8_t>(4000000); // test that receiving 4M payload works
+ auto msg_data_2 = g_insecure_rand_ctx.randbytes<uint8_t>(4000000); // test that sending 4M payload works
+ tester.SendMessage(uint8_t(InsecureRandRange(223) + 33), {}); // unknown short id
+ tester.SendMessage(uint8_t(2), msg_data_1); // "block" short id
+ tester.AddMessage("blocktxn", msg_data_2); // schedule blocktxn to be sent to us
+ ret = tester.Interact();
+ BOOST_REQUIRE(ret && ret->size() == 2);
+ BOOST_CHECK(!(*ret)[0]);
+ BOOST_CHECK((*ret)[1] && (*ret)[1]->m_type == "block" && Span{(*ret)[1]->m_recv} == MakeByteSpan(msg_data_1));
+ tester.ReceiveMessage(uint8_t(3), msg_data_2); // "blocktxn" short id
+ }
+
+ // Send correct network's V1 header
+ {
+ V2TransportTester tester(false);
+ tester.SendV1Version(Params().MessageStart());
+ auto ret = tester.Interact();
+ BOOST_CHECK(ret);
+ }
+
+ // Send wrong network's V1 header
+ {
+ V2TransportTester tester(false);
+ tester.SendV1Version(CChainParams::Main()->MessageStart());
+ auto ret = tester.Interact();
+ BOOST_CHECK(!ret);
+ }
+}
+
BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/test/netbase_tests.cpp b/src/test/netbase_tests.cpp
index 05953bfd10..e22bf7e7c0 100644
--- a/src/test/netbase_tests.cpp
+++ b/src/test/netbase_tests.cpp
@@ -559,35 +559,35 @@ static constexpr const char* stream_addrv2_hex =
BOOST_AUTO_TEST_CASE(caddress_serialize_v1)
{
- CDataStream s(SER_NETWORK, PROTOCOL_VERSION);
+ DataStream s{};
- s << fixture_addresses;
+ s << WithParams(CAddress::V1_NETWORK, fixture_addresses);
BOOST_CHECK_EQUAL(HexStr(s), stream_addrv1_hex);
}
BOOST_AUTO_TEST_CASE(caddress_unserialize_v1)
{
- CDataStream s(ParseHex(stream_addrv1_hex), SER_NETWORK, PROTOCOL_VERSION);
+ DataStream s{ParseHex(stream_addrv1_hex)};
std::vector<CAddress> addresses_unserialized;
- s >> addresses_unserialized;
+ s >> WithParams(CAddress::V1_NETWORK, addresses_unserialized);
BOOST_CHECK(fixture_addresses == addresses_unserialized);
}
BOOST_AUTO_TEST_CASE(caddress_serialize_v2)
{
- CDataStream s(SER_NETWORK, PROTOCOL_VERSION | ADDRV2_FORMAT);
+ DataStream s{};
- s << fixture_addresses;
+ s << WithParams(CAddress::V2_NETWORK, fixture_addresses);
BOOST_CHECK_EQUAL(HexStr(s), stream_addrv2_hex);
}
BOOST_AUTO_TEST_CASE(caddress_unserialize_v2)
{
- CDataStream s(ParseHex(stream_addrv2_hex), SER_NETWORK, PROTOCOL_VERSION | ADDRV2_FORMAT);
+ DataStream s{ParseHex(stream_addrv2_hex)};
std::vector<CAddress> addresses_unserialized;
- s >> addresses_unserialized;
+ s >> WithParams(CAddress::V2_NETWORK, addresses_unserialized);
BOOST_CHECK(fixture_addresses == addresses_unserialized);
}
diff --git a/src/test/serialize_tests.cpp b/src/test/serialize_tests.cpp
index 2e862621bc..2f2bb6698c 100644
--- a/src/test/serialize_tests.cpp
+++ b/src/test/serialize_tests.cpp
@@ -9,6 +9,7 @@
#include <util/strencodings.h>
#include <stdint.h>
+#include <string>
#include <boost/test/unit_test.hpp>
@@ -254,4 +255,146 @@ BOOST_AUTO_TEST_CASE(class_methods)
}
}
+enum class BaseFormat {
+ RAW,
+ HEX,
+};
+
+/// (Un)serialize a number as raw byte or 2 hexadecimal chars.
+class Base
+{
+public:
+ uint8_t m_base_data;
+
+ Base() : m_base_data(17) {}
+ explicit Base(uint8_t data) : m_base_data(data) {}
+
+ template <typename Stream>
+ void Serialize(Stream& s) const
+ {
+ if (s.GetParams() == BaseFormat::RAW) {
+ s << m_base_data;
+ } else {
+ s << Span{HexStr(Span{&m_base_data, 1})};
+ }
+ }
+
+ template <typename Stream>
+ void Unserialize(Stream& s)
+ {
+ if (s.GetParams() == BaseFormat::RAW) {
+ s >> m_base_data;
+ } else {
+ std::string hex{"aa"};
+ s >> Span{hex}.first(hex.size());
+ m_base_data = TryParseHex<uint8_t>(hex).value().at(0);
+ }
+ }
+};
+
+class DerivedAndBaseFormat
+{
+public:
+ BaseFormat m_base_format;
+
+ enum class DerivedFormat {
+ LOWER,
+ UPPER,
+ } m_derived_format;
+};
+
+class Derived : public Base
+{
+public:
+ std::string m_derived_data;
+
+ SERIALIZE_METHODS_PARAMS(Derived, obj, DerivedAndBaseFormat, fmt)
+ {
+ READWRITE(WithParams(fmt.m_base_format, AsBase<Base>(obj)));
+
+ if (ser_action.ForRead()) {
+ std::string str;
+ s >> str;
+ SER_READ(obj, obj.m_derived_data = str);
+ } else {
+ s << (fmt.m_derived_format == DerivedAndBaseFormat::DerivedFormat::LOWER ?
+ ToLower(obj.m_derived_data) :
+ ToUpper(obj.m_derived_data));
+ }
+ }
+};
+
+BOOST_AUTO_TEST_CASE(with_params_base)
+{
+ Base b{0x0F};
+
+ DataStream stream;
+
+ stream << WithParams(BaseFormat::RAW, b);
+ BOOST_CHECK_EQUAL(stream.str(), "\x0F");
+
+ b.m_base_data = 0;
+ stream >> WithParams(BaseFormat::RAW, b);
+ BOOST_CHECK_EQUAL(b.m_base_data, 0x0F);
+
+ stream.clear();
+
+ stream << WithParams(BaseFormat::HEX, b);
+ BOOST_CHECK_EQUAL(stream.str(), "0f");
+
+ b.m_base_data = 0;
+ stream >> WithParams(BaseFormat::HEX, b);
+ BOOST_CHECK_EQUAL(b.m_base_data, 0x0F);
+}
+
+BOOST_AUTO_TEST_CASE(with_params_vector_of_base)
+{
+ std::vector<Base> v{Base{0x0F}, Base{0xFF}};
+
+ DataStream stream;
+
+ stream << WithParams(BaseFormat::RAW, v);
+ BOOST_CHECK_EQUAL(stream.str(), "\x02\x0F\xFF");
+
+ v[0].m_base_data = 0;
+ v[1].m_base_data = 0;
+ stream >> WithParams(BaseFormat::RAW, v);
+ BOOST_CHECK_EQUAL(v[0].m_base_data, 0x0F);
+ BOOST_CHECK_EQUAL(v[1].m_base_data, 0xFF);
+
+ stream.clear();
+
+ stream << WithParams(BaseFormat::HEX, v);
+ BOOST_CHECK_EQUAL(stream.str(), "\x02"
+ "0fff");
+
+ v[0].m_base_data = 0;
+ v[1].m_base_data = 0;
+ stream >> WithParams(BaseFormat::HEX, v);
+ BOOST_CHECK_EQUAL(v[0].m_base_data, 0x0F);
+ BOOST_CHECK_EQUAL(v[1].m_base_data, 0xFF);
+}
+
+BOOST_AUTO_TEST_CASE(with_params_derived)
+{
+ Derived d;
+ d.m_base_data = 0x0F;
+ d.m_derived_data = "xY";
+
+ DerivedAndBaseFormat fmt;
+
+ DataStream stream;
+
+ fmt.m_base_format = BaseFormat::RAW;
+ fmt.m_derived_format = DerivedAndBaseFormat::DerivedFormat::LOWER;
+ stream << WithParams(fmt, d);
+
+ fmt.m_base_format = BaseFormat::HEX;
+ fmt.m_derived_format = DerivedAndBaseFormat::DerivedFormat::UPPER;
+ stream << WithParams(fmt, d);
+
+ BOOST_CHECK_EQUAL(stream.str(), "\x0F\x02xy"
+ "0f\x02XY");
+}
+
BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/test/streams_tests.cpp b/src/test/streams_tests.cpp
index 5232175824..589a2fd766 100644
--- a/src/test/streams_tests.cpp
+++ b/src/test/streams_tests.cpp
@@ -553,12 +553,12 @@ BOOST_AUTO_TEST_CASE(streams_buffered_file_rand)
BOOST_AUTO_TEST_CASE(streams_hashed)
{
- CDataStream stream(SER_NETWORK, INIT_PROTO_VERSION);
+ DataStream stream{};
HashedSourceWriter hash_writer{stream};
const std::string data{"bitcoin"};
hash_writer << data;
- CHashVerifier hash_verifier{&stream};
+ HashVerifier hash_verifier{stream};
std::string result;
hash_verifier >> result;
BOOST_CHECK_EQUAL(data, result);
diff --git a/src/test/util/net.cpp b/src/test/util/net.cpp
index 8015db3e80..dc64c0b4c1 100644
--- a/src/test/util/net.cpp
+++ b/src/test/util/net.cpp
@@ -33,9 +33,9 @@ void ConnmanTestMsg::Handshake(CNode& node,
Using<CustomUintFormatter<8>>(remote_services), //
int64_t{}, // dummy time
int64_t{}, // ignored service bits
- CService{}, // dummy
+ WithParams(CNetAddr::V1, CService{}), // dummy
int64_t{}, // ignored service bits
- CService{}, // ignored
+ WithParams(CNetAddr::V1, CService{}), // ignored
uint64_t{1}, // dummy nonce
std::string{}, // dummy subver
int32_t{}, // dummy starting_height
@@ -78,7 +78,7 @@ void ConnmanTestMsg::FlushSendBuffer(CNode& node) const
node.vSendMsg.clear();
node.m_send_memusage = 0;
while (true) {
- const auto& [to_send, _more, _msg_type] = node.m_transport->GetBytesToSend();
+ const auto& [to_send, _more, _msg_type] = node.m_transport->GetBytesToSend(false);
if (to_send.empty()) break;
node.m_transport->MarkBytesSent(to_send.size());
}
@@ -90,7 +90,7 @@ bool ConnmanTestMsg::ReceiveMsgFrom(CNode& node, CSerializedNetMsg&& ser_msg) co
assert(queued);
bool complete{false};
while (true) {
- const auto& [to_send, _more, _msg_type] = node.m_transport->GetBytesToSend();
+ const auto& [to_send, _more, _msg_type] = node.m_transport->GetBytesToSend(false);
if (to_send.empty()) break;
NodeReceiveMsgBytes(node, to_send, complete);
node.m_transport->MarkBytesSent(to_send.size());
diff --git a/src/validation.cpp b/src/validation.cpp
index 0b6327ec55..e3a00e4241 100644
--- a/src/validation.cpp
+++ b/src/validation.cpp
@@ -11,6 +11,7 @@
#include <arith_uint256.h>
#include <chain.h>
#include <checkqueue.h>
+#include <clientversion.h>
#include <consensus/amount.h>
#include <consensus/consensus.h>
#include <consensus/merkle.h>
diff --git a/src/version.h b/src/version.h
index 611a670314..204df3758b 100644
--- a/src/version.h
+++ b/src/version.h
@@ -36,6 +36,6 @@ static const int INVALID_CB_NO_BAN_VERSION = 70015;
static const int WTXID_RELAY_VERSION = 70016;
// Make sure that none of the values above collide with
-// `SERIALIZE_TRANSACTION_NO_WITNESS` or `ADDRV2_FORMAT`.
+// `SERIALIZE_TRANSACTION_NO_WITNESS`.
#endif // BITCOIN_VERSION_H