diff options
author | Pieter Wuille <pieter@wuille.net> | 2023-07-26 13:19:31 -0400 |
---|---|---|
committer | Pieter Wuille <pieter@wuille.net> | 2023-08-23 19:56:24 -0400 |
commit | 27f9ba23efe82531a465c5e63bf7dc62b6a3a8db (patch) | |
tree | dabff76c69490b9d17486e7a255a14121c40438a /src/net.h | |
parent | 93594e42c3f92d82427d2b284ff0f94cdbebe99c (diff) |
net: add V1Transport lock protecting receive state
Rather than relying on the caller to prevent concurrent calls to the
various receive-side functions of Transport, introduce a private m_cs_recv
inside the implementation to protect the lock state.
Of course, this does not remove the need for callers to synchronize calls
entirely, as it is a stateful object, and e.g. the order in which Receive(),
Complete(), and GetMessage() are called matters. It seems impossible to use
a Transport object in a meaningful way in a multi-threaded way without some
form of external synchronization, but it still feels safer to make the
transport object itself responsible for protecting its internal state.
Diffstat (limited to 'src/net.h')
-rw-r--r-- | src/net.h | 60 |
1 files changed, 37 insertions, 23 deletions
@@ -259,8 +259,7 @@ public: virtual ~Transport() {} // 1. Receiver side functions, for decoding bytes received on the wire into transport protocol - // agnostic CNetMessage (message type & payload) objects. Callers must guarantee that none of - // these functions are called concurrently w.r.t. one another. + // agnostic CNetMessage (message type & payload) objects. // returns true if the current deserialization is complete virtual bool Complete() const = 0; @@ -282,20 +281,22 @@ class V1Transport final : public Transport private: const CChainParams& m_chain_params; const NodeId m_node_id; // Only for logging - mutable CHash256 hasher; - mutable uint256 data_hash; - bool in_data; // parsing header (false) or data (true) - CDataStream hdrbuf; // partially received header - CMessageHeader hdr; // complete header - CDataStream vRecv; // received message data - unsigned int nHdrPos; - unsigned int nDataPos; - - const uint256& GetMessageHash() const; - int readHeader(Span<const uint8_t> msg_bytes); - int readData(Span<const uint8_t> msg_bytes); - - void Reset() { + mutable Mutex m_recv_mutex; //!< Lock for receive state + mutable CHash256 hasher GUARDED_BY(m_recv_mutex); + mutable uint256 data_hash GUARDED_BY(m_recv_mutex); + bool in_data GUARDED_BY(m_recv_mutex); // parsing header (false) or data (true) + CDataStream hdrbuf GUARDED_BY(m_recv_mutex); // partially received header + CMessageHeader hdr GUARDED_BY(m_recv_mutex); // complete header + CDataStream vRecv GUARDED_BY(m_recv_mutex); // received message data + unsigned int nHdrPos GUARDED_BY(m_recv_mutex); + unsigned int nDataPos GUARDED_BY(m_recv_mutex); + + const uint256& GetMessageHash() const EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex); + int readHeader(Span<const uint8_t> msg_bytes) EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex); + int readData(Span<const uint8_t> msg_bytes) EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex); + + void Reset() EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex) { + AssertLockHeld(m_recv_mutex); vRecv.clear(); hdrbuf.clear(); hdrbuf.resize(24); @@ -306,6 +307,13 @@ private: hasher.Reset(); } + bool CompleteInternal() const noexcept EXCLUSIVE_LOCKS_REQUIRED(m_recv_mutex) + { + AssertLockHeld(m_recv_mutex); + if (!in_data) return false; + return hdr.nMessageSize == nDataPos; + } + public: V1Transport(const CChainParams& chain_params, const NodeId node_id, int nTypeIn, int nVersionIn) : m_chain_params(chain_params), @@ -313,22 +321,28 @@ public: hdrbuf(nTypeIn, nVersionIn), vRecv(nTypeIn, nVersionIn) { + LOCK(m_recv_mutex); Reset(); } - bool Complete() const override + bool Complete() const override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex) { - if (!in_data) - return false; - return (hdr.nMessageSize == nDataPos); + AssertLockNotHeld(m_recv_mutex); + return WITH_LOCK(m_recv_mutex, return CompleteInternal()); } - void SetVersion(int nVersionIn) override + + void SetVersion(int nVersionIn) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex) { + AssertLockNotHeld(m_recv_mutex); + LOCK(m_recv_mutex); hdrbuf.SetVersion(nVersionIn); vRecv.SetVersion(nVersionIn); } - int Read(Span<const uint8_t>& msg_bytes) override + + int Read(Span<const uint8_t>& msg_bytes) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex) { + AssertLockNotHeld(m_recv_mutex); + LOCK(m_recv_mutex); int ret = in_data ? readData(msg_bytes) : readHeader(msg_bytes); if (ret < 0) { Reset(); @@ -337,7 +351,7 @@ public: } return ret; } - CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message) override; + CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message) override EXCLUSIVE_LOCKS_REQUIRED(!m_recv_mutex); void prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const override; }; |