diff options
-rw-r--r-- | src/merkleblock.h | 6 | ||||
-rw-r--r-- | src/rpc/rawtransaction.cpp | 13 | ||||
-rwxr-xr-x | test/functional/rpc_txoutproof.py | 23 | ||||
-rwxr-xr-x | test/functional/test_framework/messages.py | 46 |
4 files changed, 84 insertions, 4 deletions
diff --git a/src/merkleblock.h b/src/merkleblock.h index 0976e21c3a..984e33a961 100644 --- a/src/merkleblock.h +++ b/src/merkleblock.h @@ -115,6 +115,12 @@ public: * returns the merkle root, or 0 in case of failure */ uint256 ExtractMatches(std::vector<uint256> &vMatch, std::vector<unsigned int> &vnIndex); + + /** Get number of transactions the merkle proof is indicating for cross-reference with + * local blockchain knowledge. + */ + unsigned int GetNumTransactions() const { return nTransactions; }; + }; diff --git a/src/rpc/rawtransaction.cpp b/src/rpc/rawtransaction.cpp index 63548bff05..3e06b05aca 100644 --- a/src/rpc/rawtransaction.cpp +++ b/src/rpc/rawtransaction.cpp @@ -306,7 +306,7 @@ static UniValue verifytxoutproof(const JSONRPCRequest& request) "\nArguments:\n" "1. \"proof\" (string, required) The hex-encoded proof generated by gettxoutproof\n" "\nResult:\n" - "[\"txid\"] (array, strings) The txid(s) which the proof commits to, or empty array if the proof is invalid\n" + "[\"txid\"] (array, strings) The txid(s) which the proof commits to, or empty array if the proof can not be validated.\n" ); CDataStream ssMB(ParseHexV(request.params[0], "proof"), SER_NETWORK, PROTOCOL_VERSION | SERIALIZE_TRANSACTION_NO_WITNESS); @@ -323,12 +323,17 @@ static UniValue verifytxoutproof(const JSONRPCRequest& request) LOCK(cs_main); const CBlockIndex* pindex = LookupBlockIndex(merkleBlock.header.GetHash()); - if (!pindex || !chainActive.Contains(pindex)) { + if (!pindex || !chainActive.Contains(pindex) || pindex->nTx == 0) { throw JSONRPCError(RPC_INVALID_ADDRESS_OR_KEY, "Block not found in chain"); } - for (const uint256& hash : vMatch) - res.push_back(hash.GetHex()); + // Check if proof is valid, only add results if so + if (pindex->nTx == merkleBlock.txn.GetNumTransactions()) { + for (const uint256& hash : vMatch) { + res.push_back(hash.GetHex()); + } + } + return res; } diff --git a/test/functional/rpc_txoutproof.py b/test/functional/rpc_txoutproof.py index c52a7397dc..e5a63f0c46 100755 --- a/test/functional/rpc_txoutproof.py +++ b/test/functional/rpc_txoutproof.py @@ -6,6 +6,8 @@ from test_framework.test_framework import BitcoinTestFramework from test_framework.util import * +from test_framework.mininode import FromHex, ToHex +from test_framework.messages import CMerkleBlock class MerkleBlockTest(BitcoinTestFramework): def set_test_params(self): @@ -78,6 +80,27 @@ class MerkleBlockTest(BitcoinTestFramework): # We can't get a proof if we specify transactions from different blocks assert_raises_rpc_error(-5, "Not all transactions found in specified or retrieved block", self.nodes[2].gettxoutproof, [txid1, txid3]) + # Now we'll try tweaking a proof. + proof = self.nodes[3].gettxoutproof([txid1, txid2]) + assert txid1 in self.nodes[0].verifytxoutproof(proof) + assert txid2 in self.nodes[1].verifytxoutproof(proof) + + tweaked_proof = FromHex(CMerkleBlock(), proof) + + # Make sure that our serialization/deserialization is working + assert txid1 in self.nodes[2].verifytxoutproof(ToHex(tweaked_proof)) + + # Check to see if we can go up the merkle tree and pass this off as a + # single-transaction block + tweaked_proof.txn.nTransactions = 1 + tweaked_proof.txn.vHash = [tweaked_proof.header.hashMerkleRoot] + tweaked_proof.txn.vBits = [True] + [False]*7 + + for n in self.nodes: + assert not n.verifytxoutproof(ToHex(tweaked_proof)) + + # TODO: try more variants, eg transactions at different depths, and + # verify that the proofs are invalid if __name__ == '__main__': MerkleBlockTest().main() diff --git a/test/functional/test_framework/messages.py b/test/functional/test_framework/messages.py index ca2e425bd6..df8d424d01 100755 --- a/test/functional/test_framework/messages.py +++ b/test/functional/test_framework/messages.py @@ -841,6 +841,52 @@ class BlockTransactions(): def __repr__(self): return "BlockTransactions(hash=%064x transactions=%s)" % (self.blockhash, repr(self.transactions)) +class CPartialMerkleTree(): + def __init__(self): + self.nTransactions = 0 + self.vHash = [] + self.vBits = [] + self.fBad = False + + def deserialize(self, f): + self.nTransactions = struct.unpack("<i", f.read(4))[0] + self.vHash = deser_uint256_vector(f) + vBytes = deser_string(f) + self.vBits = [] + for i in range(len(vBytes) * 8): + self.vBits.append(vBytes[i//8] & (1 << (i % 8)) != 0) + + def serialize(self): + r = b"" + r += struct.pack("<i", self.nTransactions) + r += ser_uint256_vector(self.vHash) + vBytesArray = bytearray([0x00] * ((len(self.vBits) + 7)//8)) + for i in range(len(self.vBits)): + vBytesArray[i // 8] |= self.vBits[i] << (i % 8) + r += ser_string(bytes(vBytesArray)) + return r + + def __repr__(self): + return "CPartialMerkleTree(nTransactions=%d, vHash=%s, vBits=%s)" % (self.nTransactions, repr(self.vHash), repr(self.vBits)) + +class CMerkleBlock(): + def __init__(self): + self.header = CBlockHeader() + self.txn = CPartialMerkleTree() + + def deserialize(self, f): + self.header.deserialize(f) + self.txn.deserialize(f) + + def serialize(self): + r = b"" + r += self.header.serialize() + r += self.txn.serialize() + return r + + def __repr__(self): + return "CMerkleBlock(header=%s, txn=%s)" % (repr(self.header), repr(self.txn)) + # Objects that correspond to messages on the wire class msg_version(): |