aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorPieter Wuille <pieter@wuille.net>2023-07-26 13:19:31 -0400
committerPieter Wuille <pieter@wuille.net>2023-08-23 19:56:24 -0400
commit27f9ba23efe82531a465c5e63bf7dc62b6a3a8db (patch)
treedabff76c69490b9d17486e7a255a14121c40438a /src
parent93594e42c3f92d82427d2b284ff0f94cdbebe99c (diff)
downloadbitcoin-27f9ba23efe82531a465c5e63bf7dc62b6a3a8db.tar.xz
net: add V1Transport lock protecting receive state
Rather than relying on the caller to prevent concurrent calls to the various receive-side functions of Transport, introduce a private m_cs_recv inside the implementation to protect the lock state. Of course, this does not remove the need for callers to synchronize calls entirely, as it is a stateful object, and e.g. the order in which Receive(), Complete(), and GetMessage() are called matters. It seems impossible to use a Transport object in a meaningful way in a multi-threaded way without some form of external synchronization, but it still feels safer to make the transport object itself responsible for protecting its internal state.
Diffstat (limited to 'src')
-rw-r--r--src/net.cpp7
-rw-r--r--src/net.h60
2 files changed, 43 insertions, 24 deletions
diff --git a/src/net.cpp b/src/net.cpp
index fa20136bb1..b350c58c61 100644
--- a/src/net.cpp
+++ b/src/net.cpp
@@ -719,6 +719,7 @@ bool CNode::ReceiveMsgBytes(Span<const uint8_t> msg_bytes, bool& complete)
int V1Transport::readHeader(Span<const uint8_t> msg_bytes)
{
+ AssertLockHeld(m_recv_mutex);
// copy data to temporary parsing buffer
unsigned int nRemaining = CMessageHeader::HEADER_SIZE - nHdrPos;
unsigned int nCopy = std::min<unsigned int>(nRemaining, msg_bytes.size());
@@ -759,6 +760,7 @@ int V1Transport::readHeader(Span<const uint8_t> msg_bytes)
int V1Transport::readData(Span<const uint8_t> msg_bytes)
{
+ AssertLockHeld(m_recv_mutex);
unsigned int nRemaining = hdr.nMessageSize - nDataPos;
unsigned int nCopy = std::min<unsigned int>(nRemaining, msg_bytes.size());
@@ -776,7 +778,8 @@ int V1Transport::readData(Span<const uint8_t> msg_bytes)
const uint256& V1Transport::GetMessageHash() const
{
- assert(Complete());
+ AssertLockHeld(m_recv_mutex);
+ assert(CompleteInternal());
if (data_hash.IsNull())
hasher.Finalize(data_hash);
return data_hash;
@@ -784,9 +787,11 @@ const uint256& V1Transport::GetMessageHash() const
CNetMessage V1Transport::GetMessage(const std::chrono::microseconds time, bool& reject_message)
{
+ AssertLockNotHeld(m_recv_mutex);
// Initialize out parameter
reject_message = false;
// decompose a single CNetMessage from the TransportDeserializer
+ LOCK(m_recv_mutex);
CNetMessage msg(std::move(vRecv));
// store message type string, time, and sizes
diff --git a/src/net.h b/src/net.h
index ca6899a83a..e34ea590cc 100644
--- a/src/net.h
+++ b/src/net.h
@@ -259,8 +259,7 @@ public:
virtual ~Transport() {}
// 1. Receiver side functions, for decoding bytes received on the wire into transport protocol
- // agnostic CNetMessage (message type & payload) objects. Callers must guarantee that none of
- // these functions are called concurrently w.r.t. one another.
+ // agnostic CNetMessage (message type & payload) objects.
// returns true if the current deserialization is complete
virtual bool Complete() const = 0;
@@ -282,20 +281,22 @@ class V1Transport final : public Transport
private:
const CChainParams& m_chain_params;
const NodeId m_node_id; // Only for logging
- mutable CHash256 hasher;
- mutable uint256 data_hash;
- bool in_data; // parsing header (false) or data (true)
- CDataStream hdrbuf; // partially received header
- CMessageHeader hdr; // complete header
- CDataStream vRecv; // received message data
- unsigned int nHdrPos;
- unsigned int nDataPos;
-
- const uint256& GetMessageHash() const;
- int readHeader(Span<const uint8_t> msg_bytes);
- int readData(Span<const uint8_t> msg_bytes);
-
- void Reset() {
+ mutable Mutex m_recv_mutex; //!< Lock for receive state
+ mutable CHash256 hasher GUARDED_BY(m_recv_mutex);
+ mutable uint256 data_hash GUARDED_BY(m_recv_mutex);
+ bool in_data GUARDED_BY(m_recv_mutex); // parsing header (false) or data (true)
+ CDataStream hdrbuf GUARDED_BY(m_recv_mutex); // partially received header
+ CMessageHeader hdr GUARDED_BY(m_recv_mutex); // complete header
+ CDataStream vRecv GUARDED_BY(m_recv_mutex); // received message data
+ unsigned int nHdrPos GUARDED_BY(m_recv_mutex);
+ unsigned int nDataPos GUARDED_BY(m_recv_mutex);
+
+ const uint256& GetMessageHash() const EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex);
+ int readHeader(Span<const uint8_t> msg_bytes) EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex);
+ int readData(Span<const uint8_t> msg_bytes) EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex);
+
+ void Reset() EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex) {
+ AssertLockHeld(m_recv_mutex);
vRecv.clear();
hdrbuf.clear();
hdrbuf.resize(24);
@@ -306,6 +307,13 @@ private:
hasher.Reset();
}
+ bool CompleteInternal() const noexcept EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex)
+ {
+ AssertLockHeld(m_recv_mutex);
+ if (!in_data) return false;
+ return hdr.nMessageSize == nDataPos;
+ }
+
public:
V1Transport(const CChainParams& chain_params, const NodeId node_id, int nTypeIn, int nVersionIn)
: m_chain_params(chain_params),
@@ -313,22 +321,28 @@ public:
hdrbuf(nTypeIn, nVersionIn),
vRecv(nTypeIn, nVersionIn)
{
+ LOCK(m_recv_mutex);
Reset();
}
- bool Complete() const override
+ bool Complete() const override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex)
{
- if (!in_data)
- return false;
- return (hdr.nMessageSize == nDataPos);
+ AssertLockNotHeld(m_recv_mutex);
+ return WITH_LOCK(m_recv_mutex, return CompleteInternal());
}
- void SetVersion(int nVersionIn) override
+
+ void SetVersion(int nVersionIn) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex)
{
+ AssertLockNotHeld(m_recv_mutex);
+ LOCK(m_recv_mutex);
hdrbuf.SetVersion(nVersionIn);
vRecv.SetVersion(nVersionIn);
}
- int Read(Span<const uint8_t>& msg_bytes) override
+
+ int Read(Span<const uint8_t>& msg_bytes) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex)
{
+ AssertLockNotHeld(m_recv_mutex);
+ LOCK(m_recv_mutex);
int ret = in_data ? readData(msg_bytes) : readHeader(msg_bytes);
if (ret < 0) {
Reset();
@@ -337,7 +351,7 @@ public:
}
return ret;
}
- CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message) override;
+ CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex);
void prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const override;
};