diff options
author | Pieter Wuille <pieter@wuille.net> | 2023-07-05 16:22:52 -0400 |
---|---|---|
committer | Pieter Wuille <pieter@wuille.net> | 2023-08-23 19:56:24 -0400 |
commit | 93594e42c3f92d82427d2b284ff0f94cdbebe99c (patch) | |
tree | 2f7d4f12f650accebc2a82ae06a5f99ab081980b /src | |
parent | 23f3f402fca346302fe424427ae4077d8a458cbb (diff) |
refactor: merge transport serializer and deserializer into Transport class
This allows state that is shared between both directions to be encapsulated
into a single object. Specifically the v2 transport protocol introduced by
BIP324 has sending state (the encryption keys) that depends on received
messages (the DH key exchange). Having a single object for both means it can
hide logic from callers related to that key exchange and other interactions.
Diffstat (limited to 'src')
-rw-r--r-- | src/net.cpp | 21 | ||||
-rw-r--r-- | src/net.h | 41 | ||||
-rw-r--r-- | src/test/fuzz/p2p_transport_serialization.cpp | 15 | ||||
-rw-r--r-- | src/test/util/net.cpp | 2 |
4 files changed, 37 insertions, 42 deletions
diff --git a/src/net.cpp b/src/net.cpp index 53a2dcf125..fa20136bb1 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -681,16 +681,16 @@ bool CNode::ReceiveMsgBytes(Span<const uint8_t> msg_bytes, bool& complete) nRecvBytes += msg_bytes.size(); while (msg_bytes.size() > 0) { // absorb network data - int handled = m_deserializer->Read(msg_bytes); + int handled = m_transport->Read(msg_bytes); if (handled < 0) { // Serious header problem, disconnect from the peer. return false; } - if (m_deserializer->Complete()) { + if (m_transport->Complete()) { // decompose a transport agnostic CNetMessage from the deserializer bool reject_message{false}; - CNetMessage msg = m_deserializer->GetMessage(time, reject_message); + CNetMessage msg = m_transport->GetMessage(time, reject_message); if (reject_message) { // Message deserialization failed. Drop the message but don't disconnect the peer. // store the size of the corrupt message @@ -717,7 +717,7 @@ bool CNode::ReceiveMsgBytes(Span<const uint8_t> msg_bytes, bool& complete) return true; } -int V1TransportDeserializer::readHeader(Span<const uint8_t> msg_bytes) +int V1Transport::readHeader(Span<const uint8_t> msg_bytes) { // copy data to temporary parsing buffer unsigned int nRemaining = CMessageHeader::HEADER_SIZE - nHdrPos; @@ -757,7 +757,7 @@ int V1TransportDeserializer::readHeader(Span<const uint8_t> msg_bytes) return nCopy; } -int V1TransportDeserializer::readData(Span<const uint8_t> msg_bytes) +int V1Transport::readData(Span<const uint8_t> msg_bytes) { unsigned int nRemaining = hdr.nMessageSize - nDataPos; unsigned int nCopy = std::min<unsigned int>(nRemaining, msg_bytes.size()); @@ -774,7 +774,7 @@ int V1TransportDeserializer::readData(Span<const uint8_t> msg_bytes) return nCopy; } -const uint256& V1TransportDeserializer::GetMessageHash() const +const uint256& V1Transport::GetMessageHash() const { assert(Complete()); if (data_hash.IsNull()) @@ -782,7 +782,7 @@ const uint256& V1TransportDeserializer::GetMessageHash() const return data_hash; } -CNetMessage V1TransportDeserializer::GetMessage(const std::chrono::microseconds time, bool& reject_message) +CNetMessage V1Transport::GetMessage(const std::chrono::microseconds time, bool& reject_message) { // Initialize out parameter reject_message = false; @@ -819,7 +819,7 @@ CNetMessage V1TransportDeserializer::GetMessage(const std::chrono::microseconds return msg; } -void V1TransportSerializer::prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const +void V1Transport::prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const { // create dbl-sha256 checksum uint256 hash = Hash(msg.data); @@ -2822,8 +2822,7 @@ CNode::CNode(NodeId idIn, ConnectionType conn_type_in, bool inbound_onion, CNodeOptions&& node_opts) - : m_deserializer{std::make_unique<V1TransportDeserializer>(V1TransportDeserializer(Params(), idIn, SER_NETWORK, INIT_PROTO_VERSION))}, - m_serializer{std::make_unique<V1TransportSerializer>(V1TransportSerializer())}, + : m_transport{std::make_unique<V1Transport>(Params(), idIn, SER_NETWORK, INIT_PROTO_VERSION)}, m_permission_flags{node_opts.permission_flags}, m_sock{sock}, m_connected{GetTime<std::chrono::seconds>()}, @@ -2908,7 +2907,7 @@ void CConnman::PushMessage(CNode* pnode, CSerializedNetMsg&& msg) // make sure we use the appropriate network transport format std::vector<unsigned char> serializedHeader; - pnode->m_serializer->prepareForTransport(msg, serializedHeader); + pnode->m_transport->prepareForTransport(msg, serializedHeader); size_t nTotalSize = nMessageSize + serializedHeader.size(); size_t nBytesSent = 0; @@ -253,24 +253,31 @@ public: } }; -/** The TransportDeserializer takes care of holding and deserializing the - * network receive buffer. It can deserialize the network buffer into a - * transport protocol agnostic CNetMessage (message type & payload) - */ -class TransportDeserializer { +/** The Transport converts one connection's sent messages to wire bytes, and received bytes back. */ +class Transport { public: + virtual ~Transport() {} + + // 1. Receiver side functions, for decoding bytes received on the wire into transport protocol + // agnostic CNetMessage (message type & payload) objects. Callers must guarantee that none of + // these functions are called concurrently w.r.t. one another. + // returns true if the current deserialization is complete virtual bool Complete() const = 0; - // set the serialization context version + // set the deserialization context version virtual void SetVersion(int version) = 0; /** read and deserialize data, advances msg_bytes data pointer */ virtual int Read(Span<const uint8_t>& msg_bytes) = 0; // decomposes a message from the context virtual CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message) = 0; - virtual ~TransportDeserializer() {} + + // 2. Sending side functions: + + // prepare message for transport (header construction, error-correction computation, payload encryption, etc.) + virtual void prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const = 0; }; -class V1TransportDeserializer final : public TransportDeserializer +class V1Transport final : public Transport { private: const CChainParams& m_chain_params; @@ -300,7 +307,7 @@ private: } public: - V1TransportDeserializer(const CChainParams& chain_params, const NodeId node_id, int nTypeIn, int nVersionIn) + V1Transport(const CChainParams& chain_params, const NodeId node_id, int nTypeIn, int nVersionIn) : m_chain_params(chain_params), m_node_id(node_id), hdrbuf(nTypeIn, nVersionIn), @@ -331,19 +338,7 @@ public: return ret; } CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message) override; -}; -/** The TransportSerializer prepares messages for the network transport - */ -class TransportSerializer { -public: - // prepare message for transport (header construction, error-correction computation, payload encryption, etc.) - virtual void prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const = 0; - virtual ~TransportSerializer() {} -}; - -class V1TransportSerializer : public TransportSerializer { -public: void prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const override; }; @@ -359,8 +354,8 @@ struct CNodeOptions class CNode { public: - const std::unique_ptr<TransportDeserializer> m_deserializer; // Used only by SocketHandler thread - const std::unique_ptr<const TransportSerializer> m_serializer; + /** Transport serializer/deserializer. The receive side functions are only called under cs_vRecv. */ + const std::unique_ptr<Transport> m_transport; const NetPermissionFlags m_permission_flags; diff --git a/src/test/fuzz/p2p_transport_serialization.cpp b/src/test/fuzz/p2p_transport_serialization.cpp index 78350a600e..5e44421f1d 100644 --- a/src/test/fuzz/p2p_transport_serialization.cpp +++ b/src/test/fuzz/p2p_transport_serialization.cpp @@ -24,9 +24,10 @@ void initialize_p2p_transport_serialization() FUZZ_TARGET(p2p_transport_serialization, .init = initialize_p2p_transport_serialization) { - // Construct deserializer, with a dummy NodeId - V1TransportDeserializer deserializer{Params(), NodeId{0}, SER_NETWORK, INIT_PROTO_VERSION}; - V1TransportSerializer serializer{}; + // Construct transports for both sides, with dummy NodeIds. + V1Transport recv_transport{Params(), NodeId{0}, SER_NETWORK, INIT_PROTO_VERSION}; + V1Transport send_transport{Params(), NodeId{1}, SER_NETWORK, INIT_PROTO_VERSION}; + FuzzedDataProvider fuzzed_data_provider{buffer.data(), buffer.size()}; auto checksum_assist = fuzzed_data_provider.ConsumeBool(); @@ -63,14 +64,14 @@ FUZZ_TARGET(p2p_transport_serialization, .init = initialize_p2p_transport_serial mutable_msg_bytes.insert(mutable_msg_bytes.end(), payload_bytes.begin(), payload_bytes.end()); Span<const uint8_t> msg_bytes{mutable_msg_bytes}; while (msg_bytes.size() > 0) { - const int handled = deserializer.Read(msg_bytes); + const int handled = recv_transport.Read(msg_bytes); if (handled < 0) { break; } - if (deserializer.Complete()) { + if (recv_transport.Complete()) { const std::chrono::microseconds m_time{std::numeric_limits<int64_t>::max()}; bool reject_message{false}; - CNetMessage msg = deserializer.GetMessage(m_time, reject_message); + CNetMessage msg = recv_transport.GetMessage(m_time, reject_message); assert(msg.m_type.size() <= CMessageHeader::COMMAND_SIZE); assert(msg.m_raw_message_size <= mutable_msg_bytes.size()); assert(msg.m_raw_message_size == CMessageHeader::HEADER_SIZE + msg.m_message_size); @@ -78,7 +79,7 @@ FUZZ_TARGET(p2p_transport_serialization, .init = initialize_p2p_transport_serial std::vector<unsigned char> header; auto msg2 = CNetMsgMaker{msg.m_recv.GetVersion()}.Make(msg.m_type, Span{msg.m_recv}); - serializer.prepareForTransport(msg2, header); + send_transport.prepareForTransport(msg2, header); } } } diff --git a/src/test/util/net.cpp b/src/test/util/net.cpp index 3f72384b3b..0031770028 100644 --- a/src/test/util/net.cpp +++ b/src/test/util/net.cpp @@ -73,7 +73,7 @@ void ConnmanTestMsg::NodeReceiveMsgBytes(CNode& node, Span<const uint8_t> msg_by bool ConnmanTestMsg::ReceiveMsgFrom(CNode& node, CSerializedNetMsg& ser_msg) const { std::vector<uint8_t> ser_msg_header; - node.m_serializer->prepareForTransport(ser_msg, ser_msg_header); + node.m_transport->prepareForTransport(ser_msg, ser_msg_header); bool complete; NodeReceiveMsgBytes(node, ser_msg_header, complete); |