aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/net.cpp20
-rw-r--r--src/net.h6
-rw-r--r--src/net_processing.cpp21
-rw-r--r--src/protocol.cpp29
-rw-r--r--src/protocol.h2
-rw-r--r--src/test/fuzz/deserialize.cpp3
-rw-r--r--src/test/fuzz/p2p_transport_deserializer.cpp6
-rwxr-xr-xtest/functional/p2p_invalid_messages.py7
8 files changed, 32 insertions, 62 deletions
diff --git a/src/net.cpp b/src/net.cpp
index 941ea3c4ac..633f9a2f7f 100644
--- a/src/net.cpp
+++ b/src/net.cpp
@@ -605,6 +605,7 @@ bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete
// absorb network data
int handled = m_deserializer->Read(pch, nBytes);
if (handled < 0) {
+ // Serious header problem, disconnect from the peer.
return false;
}
@@ -616,6 +617,7 @@ bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes, bool& complete
uint32_t out_err_raw_size{0};
Optional<CNetMessage> result{m_deserializer->GetMessage(time, out_err_raw_size)};
if (!result) {
+ // Message deserialization failed. Drop the message but don't disconnect the peer.
// store the size of the corrupt message
mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER)->second += out_err_raw_size;
continue;
@@ -657,11 +659,19 @@ int V1TransportDeserializer::readHeader(const char *pch, unsigned int nBytes)
hdrbuf >> hdr;
}
catch (const std::exception&) {
+ LogPrint(BCLog::NET, "HEADER ERROR - UNABLE TO DESERIALIZE, peer=%d\n", m_node_id);
+ return -1;
+ }
+
+ // Check start string, network magic
+ if (memcmp(hdr.pchMessageStart, m_chain_params.MessageStart(), CMessageHeader::MESSAGE_START_SIZE) != 0) {
+ LogPrint(BCLog::NET, "HEADER ERROR - MESSAGESTART (%s, %u bytes), received %s, peer=%d\n", hdr.GetCommand(), hdr.nMessageSize, HexStr(hdr.pchMessageStart), m_node_id);
return -1;
}
// reject messages larger than MAX_SIZE or MAX_PROTOCOL_MESSAGE_LENGTH
if (hdr.nMessageSize > MAX_SIZE || hdr.nMessageSize > MAX_PROTOCOL_MESSAGE_LENGTH) {
+ LogPrint(BCLog::NET, "HEADER ERROR - SIZE (%s, %u bytes), peer=%d\n", hdr.GetCommand(), hdr.nMessageSize, m_node_id);
return -1;
}
@@ -701,10 +711,6 @@ Optional<CNetMessage> V1TransportDeserializer::GetMessage(const std::chrono::mic
// decompose a single CNetMessage from the TransportDeserializer
Optional<CNetMessage> msg(std::move(vRecv));
- // store state about valid header, netmagic and checksum
- msg->m_valid_header = hdr.IsValid(m_chain_params.MessageStart());
- msg->m_valid_netmagic = (memcmp(hdr.pchMessageStart, m_chain_params.MessageStart(), CMessageHeader::MESSAGE_START_SIZE) == 0);
-
// store command string, time, and sizes
msg->m_command = hdr.GetCommand();
msg->m_time = time;
@@ -716,6 +722,7 @@ Optional<CNetMessage> V1TransportDeserializer::GetMessage(const std::chrono::mic
// We just received a message off the wire, harvest entropy from the time (and the message checksum)
RandAddEvent(ReadLE32(hash.begin()));
+ // Check checksum and header command string
if (memcmp(hash.begin(), hdr.pchChecksum, CMessageHeader::CHECKSUM_SIZE) != 0) {
LogPrint(BCLog::NET, "CHECKSUM ERROR (%s, %u bytes), expected %s was %s, peer=%d\n",
SanitizeString(msg->m_command), msg->m_message_size,
@@ -724,6 +731,11 @@ Optional<CNetMessage> V1TransportDeserializer::GetMessage(const std::chrono::mic
m_node_id);
out_err_raw_size = msg->m_raw_message_size;
msg = nullopt;
+ } else if (!hdr.IsCommandValid()) {
+ LogPrint(BCLog::NET, "HEADER ERROR - COMMAND (%s, %u bytes), peer=%d\n",
+ hdr.GetCommand(), msg->m_message_size, m_node_id);
+ out_err_raw_size = msg->m_raw_message_size;
+ msg = nullopt;
}
// Always reset the network deserializer (prepare for the next message)
diff --git a/src/net.h b/src/net.h
index 29941b9622..9a92f80511 100644
--- a/src/net.h
+++ b/src/net.h
@@ -706,10 +706,8 @@ class CNetMessage {
public:
CDataStream m_recv; //!< received message data
std::chrono::microseconds m_time{0}; //!< time of message receipt
- bool m_valid_netmagic = false;
- bool m_valid_header = false;
- uint32_t m_message_size{0}; //!< size of the payload
- uint32_t m_raw_message_size{0}; //!< used wire size of the message (including header/checksum)
+ uint32_t m_message_size{0}; //!< size of the payload
+ uint32_t m_raw_message_size{0}; //!< used wire size of the message (including header/checksum)
std::string m_command;
CNetMessage(CDataStream&& recv_in) : m_recv(std::move(recv_in)) {}
diff --git a/src/net_processing.cpp b/src/net_processing.cpp
index d9d32cded6..920e7a1abf 100644
--- a/src/net_processing.cpp
+++ b/src/net_processing.cpp
@@ -3820,14 +3820,6 @@ bool PeerManager::MaybeDiscourageAndDisconnect(CNode& pnode)
bool PeerManager::ProcessMessages(CNode* pfrom, std::atomic<bool>& interruptMsgProc)
{
- //
- // Message format
- // (4) message start
- // (12) command
- // (4) size
- // (4) checksum
- // (x) data
- //
bool fMoreWork = false;
if (!pfrom->vRecvGetData.empty())
@@ -3868,19 +3860,6 @@ bool PeerManager::ProcessMessages(CNode* pfrom, std::atomic<bool>& interruptMsgP
CNetMessage& msg(msgs.front());
msg.SetVersion(pfrom->GetCommonVersion());
- // Check network magic
- if (!msg.m_valid_netmagic) {
- LogPrint(BCLog::NET, "PROCESSMESSAGE: INVALID MESSAGESTART %s peer=%d\n", SanitizeString(msg.m_command), pfrom->GetId());
- pfrom->fDisconnect = true;
- return false;
- }
-
- // Check header
- if (!msg.m_valid_header)
- {
- LogPrint(BCLog::NET, "PROCESSMESSAGE: ERRORS IN HEADER %s peer=%d\n", SanitizeString(msg.m_command), pfrom->GetId());
- return fMoreWork;
- }
const std::string& msg_type = msg.m_command;
// Message size
diff --git a/src/protocol.cpp b/src/protocol.cpp
index 6b4de68ce9..84b6e96aee 100644
--- a/src/protocol.cpp
+++ b/src/protocol.cpp
@@ -111,31 +111,20 @@ std::string CMessageHeader::GetCommand() const
return std::string(pchCommand, pchCommand + strnlen(pchCommand, COMMAND_SIZE));
}
-bool CMessageHeader::IsValid(const MessageStartChars& pchMessageStartIn) const
+bool CMessageHeader::IsCommandValid() const
{
- // Check start string
- if (memcmp(pchMessageStart, pchMessageStartIn, MESSAGE_START_SIZE) != 0)
- return false;
-
// Check the command string for errors
- for (const char* p1 = pchCommand; p1 < pchCommand + COMMAND_SIZE; p1++)
- {
- if (*p1 == 0)
- {
+ for (const char* p1 = pchCommand; p1 < pchCommand + COMMAND_SIZE; ++p1) {
+ if (*p1 == 0) {
// Must be all zeros after the first zero
- for (; p1 < pchCommand + COMMAND_SIZE; p1++)
- if (*p1 != 0)
+ for (; p1 < pchCommand + COMMAND_SIZE; ++p1) {
+ if (*p1 != 0) {
return false;
- }
- else if (*p1 < ' ' || *p1 > 0x7E)
+ }
+ }
+ } else if (*p1 < ' ' || *p1 > 0x7E) {
return false;
- }
-
- // Message size
- if (nMessageSize > MAX_SIZE)
- {
- LogPrintf("CMessageHeader::IsValid(): (%s, %u bytes) nMessageSize > MAX_SIZE\n", GetCommand(), nMessageSize);
- return false;
+ }
}
return true;
diff --git a/src/protocol.h b/src/protocol.h
index 3bf0797ca4..9a44a1626c 100644
--- a/src/protocol.h
+++ b/src/protocol.h
@@ -45,7 +45,7 @@ public:
CMessageHeader(const MessageStartChars& pchMessageStartIn, const char* pszCommand, unsigned int nMessageSizeIn);
std::string GetCommand() const;
- bool IsValid(const MessageStartChars& messageStart) const;
+ bool IsCommandValid() const;
SERIALIZE_METHODS(CMessageHeader, obj) { READWRITE(obj.pchMessageStart, obj.pchCommand, obj.nMessageSize, obj.pchChecksum); }
diff --git a/src/test/fuzz/deserialize.cpp b/src/test/fuzz/deserialize.cpp
index f87b7576a4..b799d3b43b 100644
--- a/src/test/fuzz/deserialize.cpp
+++ b/src/test/fuzz/deserialize.cpp
@@ -189,10 +189,9 @@ void test_one_input(const std::vector<uint8_t>& buffer)
DeserializeFromFuzzingInput(buffer, s);
AssertEqualAfterSerializeDeserialize(s);
#elif MESSAGEHEADER_DESERIALIZE
- const CMessageHeader::MessageStartChars pchMessageStart = {0x00, 0x00, 0x00, 0x00};
CMessageHeader mh;
DeserializeFromFuzzingInput(buffer, mh);
- (void)mh.IsValid(pchMessageStart);
+ (void)mh.IsCommandValid();
#elif ADDRESS_DESERIALIZE
CAddress a;
DeserializeFromFuzzingInput(buffer, a);
diff --git a/src/test/fuzz/p2p_transport_deserializer.cpp b/src/test/fuzz/p2p_transport_deserializer.cpp
index 6252b8e91b..7e216e16fe 100644
--- a/src/test/fuzz/p2p_transport_deserializer.cpp
+++ b/src/test/fuzz/p2p_transport_deserializer.cpp
@@ -39,12 +39,6 @@ void test_one_input(const std::vector<uint8_t>& buffer)
assert(result->m_raw_message_size <= buffer.size());
assert(result->m_raw_message_size == CMessageHeader::HEADER_SIZE + result->m_message_size);
assert(result->m_time == m_time);
- if (result->m_valid_header) {
- assert(result->m_valid_netmagic);
- }
- if (!result->m_valid_netmagic) {
- assert(!result->m_valid_header);
- }
}
}
}
diff --git a/test/functional/p2p_invalid_messages.py b/test/functional/p2p_invalid_messages.py
index fe57057a83..78a9d2e852 100755
--- a/test/functional/p2p_invalid_messages.py
+++ b/test/functional/p2p_invalid_messages.py
@@ -81,7 +81,7 @@ class InvalidMessagesTest(BitcoinTestFramework):
def test_magic_bytes(self):
self.log.info("Test message with invalid magic bytes disconnects peer")
conn = self.nodes[0].add_p2p_connection(P2PDataStore())
- with self.nodes[0].assert_debug_log(['PROCESSMESSAGE: INVALID MESSAGESTART badmsg']):
+ with self.nodes[0].assert_debug_log(['HEADER ERROR - MESSAGESTART (badmsg, 2 bytes), received ffffffff']):
msg = conn.build_message(msg_unrecognized(str_data="d"))
# modify magic bytes
msg = b'\xff' * 4 + msg[4:]
@@ -105,7 +105,7 @@ class InvalidMessagesTest(BitcoinTestFramework):
def test_size(self):
self.log.info("Test message with oversized payload disconnects peer")
conn = self.nodes[0].add_p2p_connection(P2PDataStore())
- with self.nodes[0].assert_debug_log(['']):
+ with self.nodes[0].assert_debug_log(['HEADER ERROR - SIZE (badmsg, 4000001 bytes)']):
msg = msg_unrecognized(str_data="d" * (VALID_DATA_LIMIT + 1))
msg = conn.build_message(msg)
conn.send_raw_message(msg)
@@ -115,9 +115,8 @@ class InvalidMessagesTest(BitcoinTestFramework):
def test_msgtype(self):
self.log.info("Test message with invalid message type logs an error")
conn = self.nodes[0].add_p2p_connection(P2PDataStore())
- with self.nodes[0].assert_debug_log(['PROCESSMESSAGE: ERRORS IN HEADER']):
+ with self.nodes[0].assert_debug_log(['HEADER ERROR - COMMAND']):
msg = msg_unrecognized(str_data="d")
- msg.msgtype = b'\xff' * 12
msg = conn.build_message(msg)
# Modify msgtype
msg = msg[:7] + b'\x00' + msg[7 + 1:]