aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorPieter Wuille <pieter@wuille.net>2023-07-05 16:22:52 -0400
committerPieter Wuille <pieter@wuille.net>2023-08-23 19:56:24 -0400
commit93594e42c3f92d82427d2b284ff0f94cdbebe99c (patch)
tree2f7d4f12f650accebc2a82ae06a5f99ab081980b /src
parent23f3f402fca346302fe424427ae4077d8a458cbb (diff)
downloadbitcoin-93594e42c3f92d82427d2b284ff0f94cdbebe99c.tar.xz
refactor: merge transport serializer and deserializer into Transport class
This allows state that is shared between both directions to be encapsulated into a single object. Specifically the v2 transport protocol introduced by BIP324 has sending state (the encryption keys) that depends on received messages (the DH key exchange). Having a single object for both means it can hide logic from callers related to that key exchange and other interactions.
Diffstat (limited to 'src')
-rw-r--r--src/net.cpp21
-rw-r--r--src/net.h41
-rw-r--r--src/test/fuzz/p2p_transport_serialization.cpp15
-rw-r--r--src/test/util/net.cpp2
4 files changed, 37 insertions, 42 deletions
diff --git a/src/net.cpp b/src/net.cpp
index 53a2dcf125..fa20136bb1 100644
--- a/src/net.cpp
+++ b/src/net.cpp
@@ -681,16 +681,16 @@ bool CNode::ReceiveMsgBytes(Span<const uint8_t> msg_bytes, bool& complete)
nRecvBytes += msg_bytes.size();
while (msg_bytes.size() > 0) {
// absorb network data
- int handled = m_deserializer->Read(msg_bytes);
+ int handled = m_transport->Read(msg_bytes);
if (handled < 0) {
// Serious header problem, disconnect from the peer.
return false;
}
- if (m_deserializer->Complete()) {
+ if (m_transport->Complete()) {
// decompose a transport agnostic CNetMessage from the deserializer
bool reject_message{false};
- CNetMessage msg = m_deserializer->GetMessage(time, reject_message);
+ CNetMessage msg = m_transport->GetMessage(time, reject_message);
if (reject_message) {
// Message deserialization failed. Drop the message but don't disconnect the peer.
// store the size of the corrupt message
@@ -717,7 +717,7 @@ bool CNode::ReceiveMsgBytes(Span<const uint8_t> msg_bytes, bool& complete)
return true;
}
-int V1TransportDeserializer::readHeader(Span<const uint8_t> msg_bytes)
+int V1Transport::readHeader(Span<const uint8_t> msg_bytes)
{
// copy data to temporary parsing buffer
unsigned int nRemaining = CMessageHeader::HEADER_SIZE - nHdrPos;
@@ -757,7 +757,7 @@ int V1TransportDeserializer::readHeader(Span<const uint8_t> msg_bytes)
return nCopy;
}
-int V1TransportDeserializer::readData(Span<const uint8_t> msg_bytes)
+int V1Transport::readData(Span<const uint8_t> msg_bytes)
{
unsigned int nRemaining = hdr.nMessageSize - nDataPos;
unsigned int nCopy = std::min<unsigned int>(nRemaining, msg_bytes.size());
@@ -774,7 +774,7 @@ int V1TransportDeserializer::readData(Span<const uint8_t> msg_bytes)
return nCopy;
}
-const uint256& V1TransportDeserializer::GetMessageHash() const
+const uint256& V1Transport::GetMessageHash() const
{
assert(Complete());
if (data_hash.IsNull())
@@ -782,7 +782,7 @@ const uint256& V1TransportDeserializer::GetMessageHash() const
return data_hash;
}
-CNetMessage V1TransportDeserializer::GetMessage(const std::chrono::microseconds time, bool& reject_message)
+CNetMessage V1Transport::GetMessage(const std::chrono::microseconds time, bool& reject_message)
{
// Initialize out parameter
reject_message = false;
@@ -819,7 +819,7 @@ CNetMessage V1TransportDeserializer::GetMessage(const std::chrono::microseconds
return msg;
}
-void V1TransportSerializer::prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const
+void V1Transport::prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const
{
// create dbl-sha256 checksum
uint256 hash = Hash(msg.data);
@@ -2822,8 +2822,7 @@ CNode::CNode(NodeId idIn,
ConnectionType conn_type_in,
bool inbound_onion,
CNodeOptions&& node_opts)
- : m_deserializer{std::make_unique<V1TransportDeserializer>(V1TransportDeserializer(Params(), idIn, SER_NETWORK, INIT_PROTO_VERSION))},
- m_serializer{std::make_unique<V1TransportSerializer>(V1TransportSerializer())},
+ : m_transport{std::make_unique<V1Transport>(Params(), idIn, SER_NETWORK, INIT_PROTO_VERSION)},
m_permission_flags{node_opts.permission_flags},
m_sock{sock},
m_connected{GetTime<std::chrono::seconds>()},
@@ -2908,7 +2907,7 @@ void CConnman::PushMessage(CNode* pnode, CSerializedNetMsg&& msg)
// make sure we use the appropriate network transport format
std::vector<unsigned char> serializedHeader;
- pnode->m_serializer->prepareForTransport(msg, serializedHeader);
+ pnode->m_transport->prepareForTransport(msg, serializedHeader);
size_t nTotalSize = nMessageSize + serializedHeader.size();
size_t nBytesSent = 0;
diff --git a/src/net.h b/src/net.h
index 3c1221f518..ca6899a83a 100644
--- a/src/net.h
+++ b/src/net.h
@@ -253,24 +253,31 @@ public:
}
};
-/** The TransportDeserializer takes care of holding and deserializing the
- * network receive buffer. It can deserialize the network buffer into a
- * transport protocol agnostic CNetMessage (message type & payload)
- */
-class TransportDeserializer {
+/** The Transport converts one connection's sent messages to wire bytes, and received bytes back. */
+class Transport {
public:
+ virtual ~Transport() {}
+
+ // 1. Receiver side functions, for decoding bytes received on the wire into transport protocol
+ // agnostic CNetMessage (message type & payload) objects. Callers must guarantee that none of
+ // these functions are called concurrently w.r.t. one another.
+
// returns true if the current deserialization is complete
virtual bool Complete() const = 0;
- // set the serialization context version
+ // set the deserialization context version
virtual void SetVersion(int version) = 0;
/** read and deserialize data, advances msg_bytes data pointer */
virtual int Read(Span<const uint8_t>& msg_bytes) = 0;
// decomposes a message from the context
virtual CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message) = 0;
- virtual ~TransportDeserializer() {}
+
+ // 2. Sending side functions:
+
+ // prepare message for transport (header construction, error-correction computation, payload encryption, etc.)
+ virtual void prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const = 0;
};
-class V1TransportDeserializer final : public TransportDeserializer
+class V1Transport final : public Transport
{
private:
const CChainParams& m_chain_params;
@@ -300,7 +307,7 @@ private:
}
public:
- V1TransportDeserializer(const CChainParams& chain_params, const NodeId node_id, int nTypeIn, int nVersionIn)
+ V1Transport(const CChainParams& chain_params, const NodeId node_id, int nTypeIn, int nVersionIn)
: m_chain_params(chain_params),
m_node_id(node_id),
hdrbuf(nTypeIn, nVersionIn),
@@ -331,19 +338,7 @@ public:
return ret;
}
CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message) override;
-};
-/** The TransportSerializer prepares messages for the network transport
- */
-class TransportSerializer {
-public:
- // prepare message for transport (header construction, error-correction computation, payload encryption, etc.)
- virtual void prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const = 0;
- virtual ~TransportSerializer() {}
-};
-
-class V1TransportSerializer : public TransportSerializer {
-public:
void prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const override;
};
@@ -359,8 +354,8 @@ struct CNodeOptions
class CNode
{
public:
- const std::unique_ptr<TransportDeserializer> m_deserializer; // Used only by SocketHandler thread
- const std::unique_ptr<const TransportSerializer> m_serializer;
+ /** Transport serializer/deserializer. The receive side functions are only called under cs_vRecv. */
+ const std::unique_ptr<Transport> m_transport;
const NetPermissionFlags m_permission_flags;
diff --git a/src/test/fuzz/p2p_transport_serialization.cpp b/src/test/fuzz/p2p_transport_serialization.cpp
index 78350a600e..5e44421f1d 100644
--- a/src/test/fuzz/p2p_transport_serialization.cpp
+++ b/src/test/fuzz/p2p_transport_serialization.cpp
@@ -24,9 +24,10 @@ void initialize_p2p_transport_serialization()
FUZZ_TARGET(p2p_transport_serialization, .init = initialize_p2p_transport_serialization)
{
- // Construct deserializer, with a dummy NodeId
- V1TransportDeserializer deserializer{Params(), NodeId{0}, SER_NETWORK, INIT_PROTO_VERSION};
- V1TransportSerializer serializer{};
+ // Construct transports for both sides, with dummy NodeIds.
+ V1Transport recv_transport{Params(), NodeId{0}, SER_NETWORK, INIT_PROTO_VERSION};
+ V1Transport send_transport{Params(), NodeId{1}, SER_NETWORK, INIT_PROTO_VERSION};
+
FuzzedDataProvider fuzzed_data_provider{buffer.data(), buffer.size()};
auto checksum_assist = fuzzed_data_provider.ConsumeBool();
@@ -63,14 +64,14 @@ FUZZ_TARGET(p2p_transport_serialization, .init = initialize_p2p_transport_serial
mutable_msg_bytes.insert(mutable_msg_bytes.end(), payload_bytes.begin(), payload_bytes.end());
Span<const uint8_t> msg_bytes{mutable_msg_bytes};
while (msg_bytes.size() > 0) {
- const int handled = deserializer.Read(msg_bytes);
+ const int handled = recv_transport.Read(msg_bytes);
if (handled < 0) {
break;
}
- if (deserializer.Complete()) {
+ if (recv_transport.Complete()) {
const std::chrono::microseconds m_time{std::numeric_limits<int64_t>::max()};
bool reject_message{false};
- CNetMessage msg = deserializer.GetMessage(m_time, reject_message);
+ CNetMessage msg = recv_transport.GetMessage(m_time, reject_message);
assert(msg.m_type.size() <= CMessageHeader::COMMAND_SIZE);
assert(msg.m_raw_message_size <= mutable_msg_bytes.size());
assert(msg.m_raw_message_size == CMessageHeader::HEADER_SIZE + msg.m_message_size);
@@ -78,7 +79,7 @@ FUZZ_TARGET(p2p_transport_serialization, .init = initialize_p2p_transport_serial
std::vector<unsigned char> header;
auto msg2 = CNetMsgMaker{msg.m_recv.GetVersion()}.Make(msg.m_type, Span{msg.m_recv});
- serializer.prepareForTransport(msg2, header);
+ send_transport.prepareForTransport(msg2, header);
}
}
}
diff --git a/src/test/util/net.cpp b/src/test/util/net.cpp
index 3f72384b3b..0031770028 100644
--- a/src/test/util/net.cpp
+++ b/src/test/util/net.cpp
@@ -73,7 +73,7 @@ void ConnmanTestMsg::NodeReceiveMsgBytes(CNode& node, Span<const uint8_t> msg_by
bool ConnmanTestMsg::ReceiveMsgFrom(CNode& node, CSerializedNetMsg& ser_msg) const
{
std::vector<uint8_t> ser_msg_header;
- node.m_serializer->prepareForTransport(ser_msg, ser_msg_header);
+ node.m_transport->prepareForTransport(ser_msg, ser_msg_header);
bool complete;
NodeReceiveMsgBytes(node, ser_msg_header, complete);