aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main.cpp65
-rw-r--r--src/net.cpp96
-rw-r--r--src/net.h70
3 files changed, 176 insertions, 55 deletions
diff --git a/src/main.cpp b/src/main.cpp
index 22baf0f3eb..3406144595 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -3168,7 +3168,7 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv)
else if (strCommand == "verack")
{
- pfrom->vRecv.SetVersion(min(pfrom->nVersion, PROTOCOL_VERSION));
+ pfrom->SetRecvVersion(min(pfrom->nVersion, PROTOCOL_VERSION));
}
@@ -3705,13 +3705,13 @@ bool static ProcessMessage(CNode* pfrom, string strCommand, CDataStream& vRecv)
return true;
}
+// requires LOCK(cs_vRecvMsg)
bool ProcessMessages(CNode* pfrom)
{
- CDataStream& vRecv = pfrom->vRecv;
- if (vRecv.empty())
+ if (pfrom->vRecvMsg.empty())
return true;
//if (fDebug)
- // printf("ProcessMessages(%u bytes)\n", vRecv.size());
+ // printf("ProcessMessages(%zu messages)\n", pfrom->vRecvMsg.size());
//
// Message format
@@ -3722,32 +3722,32 @@ bool ProcessMessages(CNode* pfrom)
// (x) data
//
- loop
+ unsigned int nMsgPos = 0;
+ for (; nMsgPos < pfrom->vRecvMsg.size(); nMsgPos++)
{
// Don't bother if send buffer is too full to respond anyway
if (pfrom->vSend.size() >= SendBufferSize())
break;
- // Scan for message start
- CDataStream::iterator pstart = search(vRecv.begin(), vRecv.end(), BEGIN(pchMessageStart), END(pchMessageStart));
- int nHeaderSize = vRecv.GetSerializeSize(CMessageHeader());
- if (vRecv.end() - pstart < nHeaderSize)
- {
- if ((int)vRecv.size() > nHeaderSize)
- {
- printf("\n\nPROCESSMESSAGE MESSAGESTART NOT FOUND\n\n");
- vRecv.erase(vRecv.begin(), vRecv.end() - nHeaderSize);
- }
+ // get next message; end, if an incomplete message is found
+ CNetMessage& msg = pfrom->vRecvMsg[nMsgPos];
+
+ //if (fDebug)
+ // printf("ProcessMessages(message %u msgsz, %zu bytes, complete:%s)\n",
+ // msg.hdr.nMessageSize, msg.vRecv.size(),
+ // msg.complete() ? "Y" : "N");
+
+ if (!msg.complete())
break;
+
+ // Scan for message start
+ if (memcmp(msg.hdr.pchMessageStart, pchMessageStart, sizeof(pchMessageStart)) != 0) {
+ printf("\n\nPROCESSMESSAGE: INVALID MESSAGESTART\n\n");
+ return false;
}
- if (pstart - vRecv.begin() > 0)
- printf("\n\nPROCESSMESSAGE SKIPPED %"PRIpdd" BYTES\n\n", pstart - vRecv.begin());
- vRecv.erase(vRecv.begin(), pstart);
// Read header
- vector<char> vHeaderSave(vRecv.begin(), vRecv.begin() + nHeaderSize);
- CMessageHeader hdr;
- vRecv >> hdr;
+ CMessageHeader& hdr = msg.hdr;
if (!hdr.IsValid())
{
printf("\n\nPROCESSMESSAGE: ERRORS IN HEADER %s\n\n\n", hdr.GetCommand().c_str());
@@ -3757,19 +3757,9 @@ bool ProcessMessages(CNode* pfrom)
// Message size
unsigned int nMessageSize = hdr.nMessageSize;
- if (nMessageSize > MAX_SIZE)
- {
- printf("ProcessMessages(%s, %u bytes) : nMessageSize > MAX_SIZE\n", strCommand.c_str(), nMessageSize);
- continue;
- }
- if (nMessageSize > vRecv.size())
- {
- // Rewind and wait for rest of message
- vRecv.insert(vRecv.begin(), vHeaderSave.begin(), vHeaderSave.end());
- break;
- }
// Checksum
+ CDataStream& vRecv = msg.vRecv;
uint256 hash = Hash(vRecv.begin(), vRecv.begin() + nMessageSize);
unsigned int nChecksum = 0;
memcpy(&nChecksum, &hash, sizeof(nChecksum));
@@ -3780,17 +3770,13 @@ bool ProcessMessages(CNode* pfrom)
continue;
}
- // Copy message to its own buffer
- CDataStream vMsg(vRecv.begin(), vRecv.begin() + nMessageSize, vRecv.nType, vRecv.nVersion);
- vRecv.ignore(nMessageSize);
-
// Process message
bool fRet = false;
try
{
{
LOCK(cs_main);
- fRet = ProcessMessage(pfrom, strCommand, vMsg);
+ fRet = ProcessMessage(pfrom, strCommand, vRecv);
}
if (fShutdown)
return true;
@@ -3822,7 +3808,10 @@ bool ProcessMessages(CNode* pfrom)
printf("ProcessMessage(%s, %u bytes) FAILED\n", strCommand.c_str(), nMessageSize);
}
- vRecv.Compact();
+ // remove processed messages; one incomplete message may remain
+ if (nMsgPos > 0)
+ pfrom->vRecvMsg.erase(pfrom->vRecvMsg.begin(),
+ pfrom->vRecvMsg.begin() + nMsgPos);
return true;
}
diff --git a/src/net.cpp b/src/net.cpp
index 6c8fe3ffc9..0e558228d7 100644
--- a/src/net.cpp
+++ b/src/net.cpp
@@ -536,7 +536,7 @@ void CNode::CloseSocketDisconnect()
printf("disconnecting node %s\n", addrName.c_str());
closesocket(hSocket);
hSocket = INVALID_SOCKET;
- vRecv.clear();
+ vRecvMsg.clear();
}
}
@@ -628,6 +628,78 @@ void CNode::copyStats(CNodeStats &stats)
}
#undef X
+// requires LOCK(cs_vRecvMsg)
+bool CNode::ReceiveMsgBytes(const char *pch, unsigned int nBytes)
+{
+ while (nBytes > 0) {
+
+ // get current incomplete message, or create a new one
+ if (vRecvMsg.size() == 0 ||
+ vRecvMsg.back().complete())
+ vRecvMsg.push_back(CNetMessage(SER_NETWORK, nRecvVersion));
+
+ CNetMessage& msg = vRecvMsg.back();
+
+ // absorb network data
+ int handled;
+ if (!msg.in_data)
+ handled = msg.readHeader(pch, nBytes);
+ else
+ handled = msg.readData(pch, nBytes);
+
+ if (handled < 0)
+ return false;
+
+ pch += handled;
+ nBytes -= handled;
+ }
+
+ return true;
+}
+
+int CNetMessage::readHeader(const char *pch, unsigned int nBytes)
+{
+ // copy data to temporary parsing buffer
+ unsigned int nRemaining = 24 - nHdrPos;
+ unsigned int nCopy = std::min(nRemaining, nBytes);
+
+ memcpy(&hdrbuf[nHdrPos], pch, nCopy);
+ nHdrPos += nCopy;
+
+ // if header incomplete, exit
+ if (nHdrPos < 24)
+ return nCopy;
+
+ // deserialize to CMessageHeader
+ try {
+ hdrbuf >> hdr;
+ }
+ catch (std::exception &e) {
+ return -1;
+ }
+
+ // reject messages larger than MAX_SIZE
+ if (hdr.nMessageSize > MAX_SIZE)
+ return -1;
+
+ // switch state to reading message data
+ in_data = true;
+ vRecv.resize(hdr.nMessageSize);
+
+ return nCopy;
+}
+
+int CNetMessage::readData(const char *pch, unsigned int nBytes)
+{
+ unsigned int nRemaining = hdr.nMessageSize - nDataPos;
+ unsigned int nCopy = std::min(nRemaining, nBytes);
+
+ memcpy(&vRecv[nDataPos], pch, nCopy);
+ nDataPos += nCopy;
+
+ return nCopy;
+}
+
@@ -676,7 +748,7 @@ void ThreadSocketHandler2(void* parg)
BOOST_FOREACH(CNode* pnode, vNodesCopy)
{
if (pnode->fDisconnect ||
- (pnode->GetRefCount() <= 0 && pnode->vRecv.empty() && pnode->vSend.empty()))
+ (pnode->GetRefCount() <= 0 && pnode->vRecvMsg.empty() && pnode->vSend.empty()))
{
// remove from vNodes
vNodes.erase(remove(vNodes.begin(), vNodes.end(), pnode), vNodes.end());
@@ -708,7 +780,7 @@ void ThreadSocketHandler2(void* parg)
TRY_LOCK(pnode->cs_vSend, lockSend);
if (lockSend)
{
- TRY_LOCK(pnode->cs_vRecv, lockRecv);
+ TRY_LOCK(pnode->cs_vRecvMsg, lockRecv);
if (lockRecv)
{
TRY_LOCK(pnode->cs_inventory, lockInv);
@@ -873,15 +945,12 @@ void ThreadSocketHandler2(void* parg)
continue;
if (FD_ISSET(pnode->hSocket, &fdsetRecv) || FD_ISSET(pnode->hSocket, &fdsetError))
{
- TRY_LOCK(pnode->cs_vRecv, lockRecv);
+ TRY_LOCK(pnode->cs_vRecvMsg, lockRecv);
if (lockRecv)
{
- CDataStream& vRecv = pnode->vRecv;
- unsigned int nPos = vRecv.size();
-
- if (nPos > ReceiveBufferSize()) {
+ if (pnode->GetTotalRecvSize() > ReceiveFloodSize()) {
if (!pnode->fDisconnect)
- printf("socket recv flood control disconnect (%"PRIszu" bytes)\n", vRecv.size());
+ printf("socket recv flood control disconnect (%u bytes)\n", pnode->GetTotalRecvSize());
pnode->CloseSocketDisconnect();
}
else {
@@ -890,8 +959,8 @@ void ThreadSocketHandler2(void* parg)
int nBytes = recv(pnode->hSocket, pchBuf, sizeof(pchBuf), MSG_DONTWAIT);
if (nBytes > 0)
{
- vRecv.resize(nPos + nBytes);
- memcpy(&vRecv[nPos], pchBuf, nBytes);
+ if (!pnode->ReceiveMsgBytes(pchBuf, nBytes))
+ pnode->CloseSocketDisconnect();
pnode->nLastRecv = GetTime();
}
else if (nBytes == 0)
@@ -1693,9 +1762,10 @@ void ThreadMessageHandler2(void* parg)
{
// Receive messages
{
- TRY_LOCK(pnode->cs_vRecv, lockRecv);
+ TRY_LOCK(pnode->cs_vRecvMsg, lockRecv);
if (lockRecv)
- ProcessMessages(pnode);
+ if (!ProcessMessages(pnode))
+ pnode->CloseSocketDisconnect();
}
if (fShutdown)
return;
diff --git a/src/net.h b/src/net.h
index 3b46523cd9..78f8e72fb0 100644
--- a/src/net.h
+++ b/src/net.h
@@ -27,7 +27,7 @@ extern int nBestHeight;
-inline unsigned int ReceiveBufferSize() { return 1000*GetArg("-maxreceivebuffer", 5*1000); }
+inline unsigned int ReceiveFloodSize() { return 1000*GetArg("-maxreceivebuffer", 5*1000); }
inline unsigned int SendBufferSize() { return 1000*GetArg("-maxsendbuffer", 1*1000); }
void AddOneShot(std::string strDest);
@@ -126,6 +126,44 @@ public:
+class CNetMessage {
+public:
+ bool in_data; // parsing header (false) or data (true)
+
+ CDataStream hdrbuf; // partially received header
+ CMessageHeader hdr; // complete header
+ unsigned int nHdrPos;
+
+ CDataStream vRecv; // received message data
+ unsigned int nDataPos;
+
+ CNetMessage(int nTypeIn, int nVersionIn) : hdrbuf(nTypeIn, nVersionIn), vRecv(nTypeIn, nVersionIn) {
+ hdrbuf.resize(24);
+ in_data = false;
+ nHdrPos = 0;
+ nDataPos = 0;
+ }
+
+ bool complete() const
+ {
+ if (!in_data)
+ return false;
+ return (hdr.nMessageSize == nDataPos);
+ }
+
+ void SetVersion(int nVersionIn)
+ {
+ hdrbuf.SetVersion(nVersionIn);
+ vRecv.SetVersion(nVersionIn);
+ }
+
+ int readHeader(const char *pch, unsigned int nBytes);
+ int readData(const char *pch, unsigned int nBytes);
+};
+
+
+
+
/** Information about a peer */
class CNode
@@ -135,9 +173,12 @@ public:
uint64 nServices;
SOCKET hSocket;
CDataStream vSend;
- CDataStream vRecv;
CCriticalSection cs_vSend;
- CCriticalSection cs_vRecv;
+
+ std::vector<CNetMessage> vRecvMsg;
+ CCriticalSection cs_vRecvMsg;
+ int nRecvVersion;
+
int64 nLastSend;
int64 nLastRecv;
int64 nLastSendEmpty;
@@ -191,10 +232,11 @@ public:
CCriticalSection cs_inventory;
std::multimap<int64, CInv> mapAskFor;
- CNode(SOCKET hSocketIn, CAddress addrIn, std::string addrNameIn = "", bool fInboundIn=false) : vSend(SER_NETWORK, MIN_PROTO_VERSION), vRecv(SER_NETWORK, MIN_PROTO_VERSION)
+ CNode(SOCKET hSocketIn, CAddress addrIn, std::string addrNameIn = "", bool fInboundIn=false) : vSend(SER_NETWORK, MIN_PROTO_VERSION)
{
nServices = 0;
hSocket = hSocketIn;
+ nRecvVersion = MIN_PROTO_VERSION;
nLastSend = 0;
nLastRecv = 0;
nLastSendEmpty = GetTime();
@@ -250,6 +292,26 @@ public:
return std::max(nRefCount, 0) + (GetTime() < nReleaseTime ? 1 : 0);
}
+ // requires LOCK(cs_vRecvMsg)
+ unsigned int GetTotalRecvSize()
+ {
+ unsigned int total = 0;
+ for (unsigned int i = 0; i < vRecvMsg.size(); i++)
+ total += vRecvMsg[i].vRecv.size();
+ return total;
+ }
+
+ // requires LOCK(cs_vRecvMsg)
+ bool ReceiveMsgBytes(const char *pch, unsigned int nBytes);
+
+ // requires LOCK(cs_vRecvMsg)
+ void SetRecvVersion(int nVersionIn)
+ {
+ nRecvVersion = nVersionIn;
+ for (unsigned int i = 0; i < vRecvMsg.size(); i++)
+ vRecvMsg[i].SetVersion(nVersionIn);
+ }
+
CNode* AddRef(int64 nTimeout=0)
{
if (nTimeout != 0)