diff options
Diffstat (limited to 'src/test/fuzz/util.h')
-rw-r--r-- | src/test/fuzz/util.h | 44 |
1 files changed, 38 insertions, 6 deletions
diff --git a/src/test/fuzz/util.h b/src/test/fuzz/util.h index ff79dfe5f3..4e8e501895 100644 --- a/src/test/fuzz/util.h +++ b/src/test/fuzz/util.h @@ -22,6 +22,7 @@ #include <streams.h> #include <test/fuzz/FuzzedDataProvider.h> #include <test/fuzz/fuzz.h> +#include <test/util/net.h> #include <test/util/setup_common.h> #include <txmempool.h> #include <uint256.h> @@ -86,6 +87,14 @@ template <typename T> return obj; } +template <typename WeakEnumType, size_t size> +[[nodiscard]] WeakEnumType ConsumeWeakEnum(FuzzedDataProvider& fuzzed_data_provider, const WeakEnumType (&all_types)[size]) noexcept +{ + return fuzzed_data_provider.ConsumeBool() ? + fuzzed_data_provider.PickValueInArray<WeakEnumType>(all_types) : + WeakEnumType(fuzzed_data_provider.ConsumeIntegral<typename std::underlying_type<WeakEnumType>::type>()); +} + [[nodiscard]] inline opcodetype ConsumeOpcodeType(FuzzedDataProvider& fuzzed_data_provider) noexcept { return static_cast<opcodetype>(fuzzed_data_provider.ConsumeIntegralInRange<uint32_t>(0, MAX_OPCODE)); @@ -283,22 +292,45 @@ inline CService ConsumeService(FuzzedDataProvider& fuzzed_data_provider) noexcep inline CAddress ConsumeAddress(FuzzedDataProvider& fuzzed_data_provider) noexcept { - return {ConsumeService(fuzzed_data_provider), static_cast<ServiceFlags>(fuzzed_data_provider.ConsumeIntegral<uint64_t>()), fuzzed_data_provider.ConsumeIntegral<uint32_t>()}; + return {ConsumeService(fuzzed_data_provider), ConsumeWeakEnum(fuzzed_data_provider, ALL_SERVICE_FLAGS), fuzzed_data_provider.ConsumeIntegral<uint32_t>()}; } -inline CNode ConsumeNode(FuzzedDataProvider& fuzzed_data_provider) noexcept +template <bool ReturnUniquePtr = false> +auto ConsumeNode(FuzzedDataProvider& fuzzed_data_provider, const std::optional<NodeId>& node_id_in = std::nullopt) noexcept { - const NodeId node_id = fuzzed_data_provider.ConsumeIntegral<NodeId>(); - const ServiceFlags local_services = static_cast<ServiceFlags>(fuzzed_data_provider.ConsumeIntegral<uint64_t>()); + const NodeId node_id = node_id_in.value_or(fuzzed_data_provider.ConsumeIntegral<NodeId>()); + const ServiceFlags local_services = ConsumeWeakEnum(fuzzed_data_provider, ALL_SERVICE_FLAGS); const SOCKET socket = INVALID_SOCKET; const CAddress address = ConsumeAddress(fuzzed_data_provider); const uint64_t keyed_net_group = fuzzed_data_provider.ConsumeIntegral<uint64_t>(); const uint64_t local_host_nonce = fuzzed_data_provider.ConsumeIntegral<uint64_t>(); const CAddress addr_bind = ConsumeAddress(fuzzed_data_provider); const std::string addr_name = fuzzed_data_provider.ConsumeRandomLengthString(64); - const ConnectionType conn_type = fuzzed_data_provider.PickValueInArray({ConnectionType::INBOUND, ConnectionType::OUTBOUND_FULL_RELAY, ConnectionType::MANUAL, ConnectionType::FEELER, ConnectionType::BLOCK_RELAY, ConnectionType::ADDR_FETCH}); + const ConnectionType conn_type = fuzzed_data_provider.PickValueInArray(ALL_CONNECTION_TYPES); const bool inbound_onion{conn_type == ConnectionType::INBOUND ? fuzzed_data_provider.ConsumeBool() : false}; - return {node_id, local_services, socket, address, keyed_net_group, local_host_nonce, addr_bind, addr_name, conn_type, inbound_onion}; + if constexpr (ReturnUniquePtr) { + return std::make_unique<CNode>(node_id, local_services, socket, address, keyed_net_group, local_host_nonce, addr_bind, addr_name, conn_type, inbound_onion); + } else { + return CNode{node_id, local_services, socket, address, keyed_net_group, local_host_nonce, addr_bind, addr_name, conn_type, inbound_onion}; + } +} +inline std::unique_ptr<CNode> ConsumeNodeAsUniquePtr(FuzzedDataProvider& fdp, const std::optional<NodeId>& node_id_in = nullopt) { return ConsumeNode<true>(fdp, node_id_in); } + +inline void FillNode(FuzzedDataProvider& fuzzed_data_provider, CNode& node, const std::optional<int32_t>& version_in = std::nullopt) noexcept +{ + const ServiceFlags remote_services = ConsumeWeakEnum(fuzzed_data_provider, ALL_SERVICE_FLAGS); + const NetPermissionFlags permission_flags = ConsumeWeakEnum(fuzzed_data_provider, ALL_NET_PERMISSION_FLAGS); + const int32_t version = version_in.value_or(fuzzed_data_provider.ConsumeIntegral<int32_t>()); + const bool filter_txs = fuzzed_data_provider.ConsumeBool(); + + node.nServices = remote_services; + node.m_permissionFlags = permission_flags; + node.nVersion = version; + node.SetCommonVersion(version); + if (node.m_tx_relay != nullptr) { + LOCK(node.m_tx_relay->cs_filter); + node.m_tx_relay->fRelayTxes = filter_txs; + } } inline void InitializeFuzzingContext(const std::string& chain_name = CBaseChainParams::REGTEST) |