diff options
-rw-r--r-- | src/net.cpp | 20 | ||||
-rw-r--r-- | src/net.h | 6 | ||||
-rw-r--r-- | src/net_processing.cpp | 21 | ||||
-rw-r--r-- | src/protocol.cpp | 29 | ||||
-rw-r--r-- | src/protocol.h | 2 | ||||
-rw-r--r-- | src/test/fuzz/deserialize.cpp | 3 | ||||
-rw-r--r-- | src/test/fuzz/p2p_transport_deserializer.cpp | 6 | ||||
-rwxr-xr-x | test/functional/p2p_invalid_messages.py | 7 |
8 files changed, 32 insertions, 62 deletions
diff --git a/src/net.cpp b/src/net.cpp index 941ea3c4ac..633f9a2f7f 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -605,6 +605,7 @@ bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete // absorb network data int handled = m_deserializer->Read(pch, nBytes); if (handled < 0) { + // Serious header problem, disconnect from the peer. return false; } @@ -616,6 +617,7 @@ bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete uint32_t out_err_raw_size{0}; Optional<CNetMessage> result{m_deserializer->GetMessage(time, out_err_raw_size)}; if (!result) { + // Message deserialization failed. Drop the message but don't disconnect the peer. // store the size of the corrupt message mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER)->second += out_err_raw_size; continue; @@ -657,11 +659,19 @@ int V1TransportDeserializer::readHeader(const char *pch, unsigned int nBytes) hdrbuf >> hdr; } catch (const std::exception&) { + LogPrint(BCLog::NET, "HEADER ERROR - UNABLE TO DESERIALIZE, peer=%d\n", m_node_id); + return -1; + } + + // Check start string, network magic + if (memcmp(hdr.pchMessageStart, m_chain_params.MessageStart(), CMessageHeader::MESSAGE_START_SIZE) != 0) { + LogPrint(BCLog::NET, "HEADER ERROR - MESSAGESTART (%s, %u bytes), received %s, peer=%d\n", hdr.GetCommand(), hdr.nMessageSize, HexStr(hdr.pchMessageStart), m_node_id); return -1; } // reject messages larger than MAX_SIZE or MAX_PROTOCOL_MESSAGE_LENGTH if (hdr.nMessageSize > MAX_SIZE || hdr.nMessageSize > MAX_PROTOCOL_MESSAGE_LENGTH) { + LogPrint(BCLog::NET, "HEADER ERROR - SIZE (%s, %u bytes), peer=%d\n", hdr.GetCommand(), hdr.nMessageSize, m_node_id); return -1; } @@ -701,10 +711,6 @@ Optional<CNetMessage> V1TransportDeserializer::GetMessage(const std::chrono::mic // decompose a single CNetMessage from the TransportDeserializer Optional<CNetMessage> msg(std::move(vRecv)); - // store state about valid header, netmagic and checksum - msg->m_valid_header = hdr.IsValid(m_chain_params.MessageStart()); - msg->m_valid_netmagic = (memcmp(hdr.pchMessageStart, m_chain_params.MessageStart(), CMessageHeader::MESSAGE_START_SIZE) == 0); - // store command string, time, and sizes msg->m_command = hdr.GetCommand(); msg->m_time = time; @@ -716,6 +722,7 @@ Optional<CNetMessage> V1TransportDeserializer::GetMessage(const std::chrono::mic // We just received a message off the wire, harvest entropy from the time (and the message checksum) RandAddEvent(ReadLE32(hash.begin())); + // Check checksum and header command string if (memcmp(hash.begin(), hdr.pchChecksum, CMessageHeader::CHECKSUM_SIZE) != 0) { LogPrint(BCLog::NET, "CHECKSUM ERROR (%s, %u bytes), expected %s was %s, peer=%d\n", SanitizeString(msg->m_command), msg->m_message_size, @@ -724,6 +731,11 @@ Optional<CNetMessage> V1TransportDeserializer::GetMessage(const std::chrono::mic m_node_id); out_err_raw_size = msg->m_raw_message_size; msg = nullopt; + } else if (!hdr.IsCommandValid()) { + LogPrint(BCLog::NET, "HEADER ERROR - COMMAND (%s, %u bytes), peer=%d\n", + hdr.GetCommand(), msg->m_message_size, m_node_id); + out_err_raw_size = msg->m_raw_message_size; + msg = nullopt; } // Always reset the network deserializer (prepare for the next message) @@ -706,10 +706,8 @@ class CNetMessage { public: CDataStream m_recv; //!< received message data std::chrono::microseconds m_time{0}; //!< time of message receipt - bool m_valid_netmagic = false; - bool m_valid_header = false; - uint32_t m_message_size{0}; //!< size of the payload - uint32_t m_raw_message_size{0}; //!< used wire size of the message (including header/checksum) + uint32_t m_message_size{0}; //!< size of the payload + uint32_t m_raw_message_size{0}; //!< used wire size of the message (including header/checksum) std::string m_command; CNetMessage(CDataStream&& recv_in) : m_recv(std::move(recv_in)) {} diff --git a/src/net_processing.cpp b/src/net_processing.cpp index d9d32cded6..920e7a1abf 100644 --- a/src/net_processing.cpp +++ b/src/net_processing.cpp @@ -3820,14 +3820,6 @@ bool PeerManager::MaybeDiscourageAndDisconnect(CNode& pnode) bool PeerManager::ProcessMessages(CNode* pfrom, std::atomic<bool>& interruptMsgProc) { - // - // Message format - // (4) message start - // (12) command - // (4) size - // (4) checksum - // (x) data - // bool fMoreWork = false; if (!pfrom->vRecvGetData.empty()) @@ -3868,19 +3860,6 @@ bool PeerManager::ProcessMessages(CNode* pfrom, std::atomic<bool>& interruptMsgP CNetMessage& msg(msgs.front()); msg.SetVersion(pfrom->GetCommonVersion()); - // Check network magic - if (!msg.m_valid_netmagic) { - LogPrint(BCLog::NET, "PROCESSMESSAGE: INVALID MESSAGESTART %s peer=%d\n", SanitizeString(msg.m_command), pfrom->GetId()); - pfrom->fDisconnect = true; - return false; - } - - // Check header - if (!msg.m_valid_header) - { - LogPrint(BCLog::NET, "PROCESSMESSAGE: ERRORS IN HEADER %s peer=%d\n", SanitizeString(msg.m_command), pfrom->GetId()); - return fMoreWork; - } const std::string& msg_type = msg.m_command; // Message size diff --git a/src/protocol.cpp b/src/protocol.cpp index 6b4de68ce9..84b6e96aee 100644 --- a/src/protocol.cpp +++ b/src/protocol.cpp @@ -111,31 +111,20 @@ std::string CMessageHeader::GetCommand() const return std::string(pchCommand, pchCommand + strnlen(pchCommand, COMMAND_SIZE)); } -bool CMessageHeader::IsValid(const MessageStartChars& pchMessageStartIn) const +bool CMessageHeader::IsCommandValid() const { - // Check start string - if (memcmp(pchMessageStart, pchMessageStartIn, MESSAGE_START_SIZE) != 0) - return false; - // Check the command string for errors - for (const char* p1 = pchCommand; p1 < pchCommand + COMMAND_SIZE; p1++) - { - if (*p1 == 0) - { + for (const char* p1 = pchCommand; p1 < pchCommand + COMMAND_SIZE; ++p1) { + if (*p1 == 0) { // Must be all zeros after the first zero - for (; p1 < pchCommand + COMMAND_SIZE; p1++) - if (*p1 != 0) + for (; p1 < pchCommand + COMMAND_SIZE; ++p1) { + if (*p1 != 0) { return false; - } - else if (*p1 < ' ' || *p1 > 0x7E) + } + } + } else if (*p1 < ' ' || *p1 > 0x7E) { return false; - } - - // Message size - if (nMessageSize > MAX_SIZE) - { - LogPrintf("CMessageHeader::IsValid(): (%s, %u bytes) nMessageSize > MAX_SIZE\n", GetCommand(), nMessageSize); - return false; + } } return true; diff --git a/src/protocol.h b/src/protocol.h index 3bf0797ca4..9a44a1626c 100644 --- a/src/protocol.h +++ b/src/protocol.h @@ -45,7 +45,7 @@ public: CMessageHeader(const MessageStartChars& pchMessageStartIn, const char* pszCommand, unsigned int nMessageSizeIn); std::string GetCommand() const; - bool IsValid(const MessageStartChars& messageStart) const; + bool IsCommandValid() const; SERIALIZE_METHODS(CMessageHeader, obj) { READWRITE(obj.pchMessageStart, obj.pchCommand, obj.nMessageSize, obj.pchChecksum); } diff --git a/src/test/fuzz/deserialize.cpp b/src/test/fuzz/deserialize.cpp index f87b7576a4..b799d3b43b 100644 --- a/src/test/fuzz/deserialize.cpp +++ b/src/test/fuzz/deserialize.cpp @@ -189,10 +189,9 @@ void test_one_input(const std::vector<uint8_t>& buffer) DeserializeFromFuzzingInput(buffer, s); AssertEqualAfterSerializeDeserialize(s); #elif MESSAGEHEADER_DESERIALIZE - const CMessageHeader::MessageStartChars pchMessageStart = {0x00, 0x00, 0x00, 0x00}; CMessageHeader mh; DeserializeFromFuzzingInput(buffer, mh); - (void)mh.IsValid(pchMessageStart); + (void)mh.IsCommandValid(); #elif ADDRESS_DESERIALIZE CAddress a; DeserializeFromFuzzingInput(buffer, a); diff --git a/src/test/fuzz/p2p_transport_deserializer.cpp b/src/test/fuzz/p2p_transport_deserializer.cpp index 6252b8e91b..7e216e16fe 100644 --- a/src/test/fuzz/p2p_transport_deserializer.cpp +++ b/src/test/fuzz/p2p_transport_deserializer.cpp @@ -39,12 +39,6 @@ void test_one_input(const std::vector<uint8_t>& buffer) assert(result->m_raw_message_size <= buffer.size()); assert(result->m_raw_message_size == CMessageHeader::HEADER_SIZE + result->m_message_size); assert(result->m_time == m_time); - if (result->m_valid_header) { - assert(result->m_valid_netmagic); - } - if (!result->m_valid_netmagic) { - assert(!result->m_valid_header); - } } } } diff --git a/test/functional/p2p_invalid_messages.py b/test/functional/p2p_invalid_messages.py index fe57057a83..78a9d2e852 100755 --- a/test/functional/p2p_invalid_messages.py +++ b/test/functional/p2p_invalid_messages.py @@ -81,7 +81,7 @@ class InvalidMessagesTest(BitcoinTestFramework): def test_magic_bytes(self): self.log.info("Test message with invalid magic bytes disconnects peer") conn = self.nodes[0].add_p2p_connection(P2PDataStore()) - with self.nodes[0].assert_debug_log(['PROCESSMESSAGE: INVALID MESSAGESTART badmsg']): + with self.nodes[0].assert_debug_log(['HEADER ERROR - MESSAGESTART (badmsg, 2 bytes), received ffffffff']): msg = conn.build_message(msg_unrecognized(str_data="d")) # modify magic bytes msg = b'\xff' * 4 + msg[4:] @@ -105,7 +105,7 @@ class InvalidMessagesTest(BitcoinTestFramework): def test_size(self): self.log.info("Test message with oversized payload disconnects peer") conn = self.nodes[0].add_p2p_connection(P2PDataStore()) - with self.nodes[0].assert_debug_log(['']): + with self.nodes[0].assert_debug_log(['HEADER ERROR - SIZE (badmsg, 4000001 bytes)']): msg = msg_unrecognized(str_data="d" * (VALID_DATA_LIMIT + 1)) msg = conn.build_message(msg) conn.send_raw_message(msg) @@ -115,9 +115,8 @@ class InvalidMessagesTest(BitcoinTestFramework): def test_msgtype(self): self.log.info("Test message with invalid message type logs an error") conn = self.nodes[0].add_p2p_connection(P2PDataStore()) - with self.nodes[0].assert_debug_log(['PROCESSMESSAGE: ERRORS IN HEADER']): + with self.nodes[0].assert_debug_log(['HEADER ERROR - COMMAND']): msg = msg_unrecognized(str_data="d") - msg.msgtype = b'\xff' * 12 msg = conn.build_message(msg) # Modify msgtype msg = msg[:7] + b'\x00' + msg[7 + 1:] |