From ab1ea29ba1b8379a21fabd3dc859552c470a6421 Mon Sep 17 00:00:00 2001 From: pasta Date: Mon, 31 Jan 2022 19:32:59 +0700 Subject: refactor: make GetRand a template, remove GetRandInt --- src/addrdb.cpp | 3 +-- src/blockencodings.cpp | 2 +- src/common/bloom.cpp | 2 +- src/init.cpp | 4 ++-- src/net.cpp | 2 +- src/net_processing.cpp | 8 ++++---- src/netaddress.h | 4 ++-- src/random.cpp | 7 +------ src/random.h | 13 +++++++++++-- src/test/random_tests.cpp | 8 ++++---- src/util/bytevectorhash.cpp | 6 +++--- src/util/hasher.cpp | 6 +++--- 12 files changed, 34 insertions(+), 31 deletions(-) (limited to 'src') diff --git a/src/addrdb.cpp b/src/addrdb.cpp index 0a76f83150..299cbdcf6a 100644 --- a/src/addrdb.cpp +++ b/src/addrdb.cpp @@ -49,8 +49,7 @@ template bool SerializeFileDB(const std::string& prefix, const fs::path& path, const Data& data, int version) { // Generate random temporary filename - uint16_t randv = 0; - GetRandBytes({(unsigned char*)&randv, sizeof(randv)}); + const uint16_t randv{GetRand()}; std::string tmpfn = strprintf("%s.%04x", prefix, randv); // open temp output file, and associate with CAutoFile diff --git a/src/blockencodings.cpp b/src/blockencodings.cpp index aa111b5939..2a7bf9397c 100644 --- a/src/blockencodings.cpp +++ b/src/blockencodings.cpp @@ -17,7 +17,7 @@ #include CBlockHeaderAndShortTxIDs::CBlockHeaderAndShortTxIDs(const CBlock& block, bool fUseWTXID) : - nonce(GetRand(std::numeric_limits::max())), + nonce(GetRand()), shorttxids(block.vtx.size() - 1), prefilledtxn(1), header(block) { FillShortTxIDSelector(); //TODO: Use our mempool prior to block acceptance to predictively fill more than just the coinbase diff --git a/src/common/bloom.cpp b/src/common/bloom.cpp index 8b32a6c94a..aa3fcf1ce2 100644 --- a/src/common/bloom.cpp +++ b/src/common/bloom.cpp @@ -239,7 +239,7 @@ bool CRollingBloomFilter::contains(Span vKey) const void CRollingBloomFilter::reset() { - nTweak = GetRand(std::numeric_limits::max()); + nTweak = GetRand(); nEntriesThisGeneration = 0; nGeneration = 1; std::fill(data.begin(), data.end(), 0); diff --git a/src/init.cpp b/src/init.cpp index aa1cff761e..dc99e78555 100644 --- a/src/init.cpp +++ b/src/init.cpp @@ -1268,8 +1268,8 @@ bool AppInitMain(NodeContext& node, interfaces::BlockAndHeaderTipInfo* tip_info) assert(!node.banman); node.banman = std::make_unique(gArgs.GetDataDirNet() / "banlist", &uiInterface, args.GetIntArg("-bantime", DEFAULT_MISBEHAVING_BANTIME)); assert(!node.connman); - node.connman = std::make_unique(GetRand(std::numeric_limits::max()), - GetRand(std::numeric_limits::max()), + node.connman = std::make_unique(GetRand(), + GetRand(), *node.addrman, *node.netgroupman, args.GetBoolArg("-networkactive", true)); assert(!node.fee_estimator); diff --git a/src/net.cpp b/src/net.cpp index 77fa06ce26..586f7d671b 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -2154,7 +2154,7 @@ void CConnman::ThreadOpenConnections(const std::vector connect) if (fFeeler) { // Add small amount of random noise before connection to avoid synchronization. - int randsleep = GetRandInt(FEELER_SLEEP_WINDOW * 1000); + int randsleep = GetRand(FEELER_SLEEP_WINDOW * 1000); if (!interruptNet.sleep_for(std::chrono::milliseconds(randsleep))) return; LogPrint(BCLog::NET, "Making feeler connection to %s\n", addrConnect.ToString()); diff --git a/src/net_processing.cpp b/src/net_processing.cpp index 46b4d2e3df..df422fa8e3 100644 --- a/src/net_processing.cpp +++ b/src/net_processing.cpp @@ -4470,10 +4470,10 @@ void PeerManagerImpl::MaybeSendPing(CNode& node_to, Peer& peer, std::chrono::mic } if (pingSend) { - uint64_t nonce = 0; - while (nonce == 0) { - GetRandBytes({(unsigned char*)&nonce, sizeof(nonce)}); - } + uint64_t nonce; + do { + nonce = GetRand(); + } while (nonce == 0); peer.m_ping_queued = false; peer.m_ping_start = now; if (node_to.GetCommonVersion() > BIP0031_VERSION) { diff --git a/src/netaddress.h b/src/netaddress.h index b9a8dc589a..77e6171054 100644 --- a/src/netaddress.h +++ b/src/netaddress.h @@ -556,8 +556,8 @@ class CServiceHash { public: CServiceHash() - : m_salt_k0{GetRand(std::numeric_limits::max())}, - m_salt_k1{GetRand(std::numeric_limits::max())} + : m_salt_k0{GetRand()}, + m_salt_k1{GetRand()} { } diff --git a/src/random.cpp b/src/random.cpp index 6ae08103b1..ad8568bef0 100644 --- a/src/random.cpp +++ b/src/random.cpp @@ -586,16 +586,11 @@ void RandAddEvent(const uint32_t event_info) noexcept { GetRNGState().AddEvent(e bool g_mock_deterministic_tests{false}; -uint64_t GetRand(uint64_t nMax) noexcept +uint64_t GetRandInternal(uint64_t nMax) noexcept { return FastRandomContext(g_mock_deterministic_tests).randrange(nMax); } -int GetRandInt(int nMax) noexcept -{ - return GetRand(nMax); -} - uint256 GetRandHash() noexcept { uint256 hash; diff --git a/src/random.h b/src/random.h index 285158b1c3..40adf29010 100644 --- a/src/random.h +++ b/src/random.h @@ -69,7 +69,17 @@ */ void GetRandBytes(Span bytes) noexcept; /** Generate a uniform random integer in the range [0..range). Precondition: range > 0 */ -uint64_t GetRand(uint64_t nMax) noexcept; +uint64_t GetRandInternal(uint64_t nMax) noexcept; +/** Generate a uniform random integer of type T in the range [0..nMax) + * nMax defaults to std::numeric_limits::max() + * Precondition: nMax > 0, T is an integral type, no larger than uint64_t + */ +template +T GetRand(T nMax=std::numeric_limits::max()) noexcept { + static_assert(std::is_integral(), "T must be integral"); + static_assert(std::numeric_limits::max() <= std::numeric_limits::max(), "GetRand only supports up to uint64_t"); + return T(GetRandInternal(nMax)); +} /** Generate a uniform random duration in the range [0..max). Precondition: max.count() > 0 */ template D GetRandomDuration(typename std::common_type::type max) noexcept @@ -95,7 +105,6 @@ constexpr auto GetRandMillis = GetRandomDuration; * */ std::chrono::microseconds GetExponentialRand(std::chrono::microseconds now, std::chrono::seconds average_interval); -int GetRandInt(int nMax) noexcept; uint256 GetRandHash() noexcept; /** diff --git a/src/test/random_tests.cpp b/src/test/random_tests.cpp index 978a7bee4d..eba7b51592 100644 --- a/src/test/random_tests.cpp +++ b/src/test/random_tests.cpp @@ -26,8 +26,8 @@ BOOST_AUTO_TEST_CASE(fastrandom_tests) FastRandomContext ctx2(true); for (int i = 10; i > 0; --i) { - BOOST_CHECK_EQUAL(GetRand(std::numeric_limits::max()), uint64_t{10393729187455219830U}); - BOOST_CHECK_EQUAL(GetRandInt(std::numeric_limits::max()), int{769702006}); + BOOST_CHECK_EQUAL(GetRand(), uint64_t{10393729187455219830U}); + BOOST_CHECK_EQUAL(GetRand(), int{769702006}); BOOST_CHECK_EQUAL(GetRandMicros(std::chrono::hours{1}).count(), 2917185654); BOOST_CHECK_EQUAL(GetRandMillis(std::chrono::hours{1}).count(), 2144374); } @@ -47,8 +47,8 @@ BOOST_AUTO_TEST_CASE(fastrandom_tests) // Check that a nondeterministic ones are not g_mock_deterministic_tests = false; for (int i = 10; i > 0; --i) { - BOOST_CHECK(GetRand(std::numeric_limits::max()) != uint64_t{10393729187455219830U}); - BOOST_CHECK(GetRandInt(std::numeric_limits::max()) != int{769702006}); + BOOST_CHECK(GetRand() != uint64_t{10393729187455219830U}); + BOOST_CHECK(GetRand() != int{769702006}); BOOST_CHECK(GetRandMicros(std::chrono::hours{1}) != std::chrono::microseconds{2917185654}); BOOST_CHECK(GetRandMillis(std::chrono::hours{1}) != std::chrono::milliseconds{2144374}); } diff --git a/src/util/bytevectorhash.cpp b/src/util/bytevectorhash.cpp index bc060a44c9..9054db4759 100644 --- a/src/util/bytevectorhash.cpp +++ b/src/util/bytevectorhash.cpp @@ -6,10 +6,10 @@ #include #include -ByteVectorHash::ByteVectorHash() +ByteVectorHash::ByteVectorHash() : + m_k0(GetRand()), + m_k1(GetRand()) { - GetRandBytes({reinterpret_cast(&m_k0), sizeof(m_k0)}); - GetRandBytes({reinterpret_cast(&m_k1), sizeof(m_k1)}); } size_t ByteVectorHash::operator()(const std::vector& input) const diff --git a/src/util/hasher.cpp b/src/util/hasher.cpp index 5900daf050..c21941eb88 100644 --- a/src/util/hasher.cpp +++ b/src/util/hasher.cpp @@ -7,11 +7,11 @@ #include -SaltedTxidHasher::SaltedTxidHasher() : k0(GetRand(std::numeric_limits::max())), k1(GetRand(std::numeric_limits::max())) {} +SaltedTxidHasher::SaltedTxidHasher() : k0(GetRand()), k1(GetRand()) {} -SaltedOutpointHasher::SaltedOutpointHasher() : k0(GetRand(std::numeric_limits::max())), k1(GetRand(std::numeric_limits::max())) {} +SaltedOutpointHasher::SaltedOutpointHasher() : k0(GetRand()), k1(GetRand()) {} -SaltedSipHasher::SaltedSipHasher() : m_k0(GetRand(std::numeric_limits::max())), m_k1(GetRand(std::numeric_limits::max())) {} +SaltedSipHasher::SaltedSipHasher() : m_k0(GetRand()), m_k1(GetRand()) {} size_t SaltedSipHasher::operator()(const Span& script) const { -- cgit v1.2.3