#!/usr/bin/env python3
# Copyright (c) 2018-2021 The Bitcoin Core developers
# Distributed under the MIT software license, see the accompanying
# file COPYING or http://www.opensource.org/licenses/mit-license.php.
"""Useful util functions for testing the wallet"""
from collections import namedtuple
import unittest

from test_framework.address import (
    byte_to_base58,
    key_to_p2pkh,
    key_to_p2sh_p2wpkh,
    key_to_p2wpkh,
    script_to_p2sh,
    script_to_p2sh_p2wsh,
    script_to_p2wsh,
)
from test_framework.key import ECKey
from test_framework.messages import (
    CTxIn,
    CTxInWitness,
    WITNESS_SCALE_FACTOR,
)
from test_framework.script_util import (
    key_to_p2pkh_script,
    key_to_p2wpkh_script,
    keys_to_multisig_script,
    script_to_p2sh_script,
    script_to_p2wsh_script,
)

# Bundle of a single key together with the scripts (hex) and addresses
# derived from it for each of the supported single-key output types.
Key = namedtuple('Key', ['privkey',
                         'pubkey',
                         'p2pkh_script',
                         'p2pkh_addr',
                         'p2wpkh_script',
                         'p2wpkh_addr',
                         'p2sh_p2wpkh_script',
                         'p2sh_p2wpkh_redeem_script',
                         'p2sh_p2wpkh_addr'])

# Bundle of a multisig's keys together with the scripts (hex) and addresses
# derived from it for each of the supported script-hash output types.
Multisig = namedtuple('Multisig', ['privkeys',
                                   'pubkeys',
                                   'p2sh_script',
                                   'p2sh_addr',
                                   'redeem_script',
                                   'p2wsh_script',
                                   'p2wsh_addr',
                                   'p2sh_p2wsh_script',
                                   'p2sh_p2wsh_addr'])


def get_key(node):
    """Generate a fresh key on node

    The key is obtained through the node's `getnewaddress`,
    `getaddressinfo` and `dumpprivkey` RPCs.

    Returns a named tuple of privkey, pubkey and all address and scripts."""
    addr = node.getnewaddress()
    pubkey = node.getaddressinfo(addr)['pubkey']
    # The P2WPKH script doubles as the redeem script of the P2SH-wrapped
    # variant, so build it once and reuse it.
    p2wpkh_script = key_to_p2wpkh_script(pubkey)
    return Key(privkey=node.dumpprivkey(addr),
               pubkey=pubkey,
               p2pkh_script=key_to_p2pkh_script(pubkey).hex(),
               p2pkh_addr=key_to_p2pkh(pubkey),
               p2wpkh_script=p2wpkh_script.hex(),
               p2wpkh_addr=key_to_p2wpkh(pubkey),
               p2sh_p2wpkh_script=script_to_p2sh_script(p2wpkh_script).hex(),
               p2sh_p2wpkh_redeem_script=p2wpkh_script.hex(),
               p2sh_p2wpkh_addr=key_to_p2sh_p2wpkh(pubkey))


def get_generate_key():
    """Generate a fresh key

    Unlike `get_key`, no node is involved: the keypair comes from
    `generate_keypair(wif=True)`, so `privkey` is a WIF string and
    `pubkey` is stored hex-encoded.

    Returns a named tuple of privkey, pubkey and all address and scripts."""
    privkey, pubkey = generate_keypair(wif=True)
    # The P2WPKH script doubles as the redeem script of the P2SH-wrapped
    # variant, so build it once and reuse it.
    p2wpkh_script = key_to_p2wpkh_script(pubkey)
    return Key(privkey=privkey,
               pubkey=pubkey.hex(),
               p2pkh_script=key_to_p2pkh_script(pubkey).hex(),
               p2pkh_addr=key_to_p2pkh(pubkey),
               p2wpkh_script=p2wpkh_script.hex(),
               p2wpkh_addr=key_to_p2wpkh(pubkey),
               p2sh_p2wpkh_script=script_to_p2sh_script(p2wpkh_script).hex(),
               p2sh_p2wpkh_redeem_script=p2wpkh_script.hex(),
               p2sh_p2wpkh_addr=key_to_p2sh_p2wpkh(pubkey))


def get_multisig(node):
    """Generate a fresh 2-of-3 multisig on node

    Three new addresses are requested from the node; their public keys
    form the multisig script and their private keys are dumped.

    Returns a named tuple of privkeys, pubkeys and all address and scripts."""
    addrs = []
    pubkeys = []
    for _ in range(3):
        addr = node.getaddressinfo(node.getnewaddress())
        addrs.append(addr['address'])
        pubkeys.append(addr['pubkey'])
    script_code = keys_to_multisig_script(pubkeys, k=2)
    witness_script = script_to_p2wsh_script(script_code)
    return Multisig(privkeys=[node.dumpprivkey(addr) for addr in addrs],
                    pubkeys=pubkeys,
                    p2sh_script=script_to_p2sh_script(script_code).hex(),
                    p2sh_addr=script_to_p2sh(script_code),
                    redeem_script=script_code.hex(),
                    p2wsh_script=witness_script.hex(),
                    p2wsh_addr=script_to_p2wsh(script_code),
                    p2sh_p2wsh_script=script_to_p2sh_script(witness_script).hex(),
                    p2sh_p2wsh_addr=script_to_p2sh_p2wsh(script_code))


def test_address(node, address, **kwargs):
    """Get address info for `address` and test whether the returned values are as expected.

    Each keyword argument names a field of the `getaddressinfo` result:
      - a value of None asserts that the field is absent from the result;
      - any other value asserts that the field is present and equal to it.

    Raises AssertionError on the first expectation that does not hold."""
    addr_info = node.getaddressinfo(address)
    for key, value in kwargs.items():
        if value is None:
            if key in addr_info:
                raise AssertionError("key {} unexpectedly returned in getaddressinfo.".format(key))
        elif key not in addr_info:
            # Report a missing field as a test failure with context rather
            # than letting the lookup below surface as a bare KeyError.
            raise AssertionError("key {} with expected value {} was not returned in getaddressinfo.".format(key, value))
        elif addr_info[key] != value:
            raise AssertionError("key {} value {} did not match expected value {}".format(key, addr_info[key], value))


def bytes_to_wif(b, compressed=True):
    """Encode raw private key bytes `b` as a WIF (wallet import format) string.

    When `compressed` is true, a 0x01 byte is appended to the key bytes
    before encoding. The payload is encoded via `byte_to_base58` with 239
    as its second argument (presumably the version byte for test
    networks -- confirm against `byte_to_base58`)."""
    if compressed:
        # Build a new object instead of `b += ...` so that a mutable
        # buffer (e.g. bytearray) passed by the caller is left untouched.
        b = b + b'\x01'
    return byte_to_base58(b, 239)


def generate_keypair(compressed=True, wif=False):
    """Generate a new random keypair and return the corresponding ECKey /
    bytes objects. The private key can also be provided as WIF (wallet
    import format) string instead, which is often useful for wallet RPC
    interaction.

    Returns a `(privkey, pubkey)` tuple where `pubkey` is always bytes and
    `privkey` is an ECKey, or a WIF string if `wif` is true."""
    privkey = ECKey()
    privkey.generate(compressed)
    pubkey = privkey.get_pubkey().get_bytes()
    if wif:
        privkey = bytes_to_wif(privkey.get_bytes(), compressed)
    return privkey, pubkey


def calculate_input_weight(scriptsig_hex, witness_stack_hex=None):
    """Given a scriptSig and a list of witness stack items for an input in hex format,
    calculate the total input weight. If the input has no witness data,
    `witness_stack_hex` can be set to None.

    The serialized input (non-witness data) is scaled by
    WITNESS_SCALE_FACTOR; the serialized witness is counted unscaled.
    An empty list still contributes the serialized (empty) witness stack,
    whereas None contributes nothing."""
    tx_in = CTxIn(scriptSig=bytes.fromhex(scriptsig_hex))
    witness_size = 0
    if witness_stack_hex is not None:
        tx_inwit = CTxInWitness()
        for witness_item_hex in witness_stack_hex:
            tx_inwit.scriptWitness.stack.append(bytes.fromhex(witness_item_hex))
        witness_size = len(tx_inwit.serialize())
    return len(tx_in.serialize()) * WITNESS_SCALE_FACTOR + witness_size


class WalletUnlock():
    """
    A context manager for unlocking a wallet with a passphrase and automatically locking it afterward.

    On entry `wallet.walletpassphrase(passphrase, timeout)` is called; on
    exit (whether or not an exception occurred) `wallet.walletlock()` is
    called. Exceptions raised inside the `with` body are not suppressed.
    """

    # Default unlock duration handed to `walletpassphrase`.
    MAXIMUM_TIMEOUT = 999000

    def __init__(self, wallet, passphrase, timeout=MAXIMUM_TIMEOUT):
        self.wallet = wallet
        self.passphrase = passphrase
        self.timeout = timeout

    def __enter__(self):
        self.wallet.walletpassphrase(self.passphrase, self.timeout)

    def __exit__(self, *args):
        # The exception details are deliberately ignored; the assignment
        # presumably only exists to mark `args` as used for linters.
        _ = args
        self.wallet.walletlock()


class TestFrameworkWalletUtil(unittest.TestCase):
    """Unit tests for the helpers in this module."""

    def test_calculate_input_weight(self):
        """Check input weights across small/large scriptSigs and witness stacks."""
        SKELETON_BYTES = 32 + 4 + 4  # prevout-txid, prevout-index, sequence
        SMALL_LEN_BYTES = 1  # bytes needed for encoding scriptSig / witness item lengths < 253
        LARGE_LEN_BYTES = 3  # bytes needed for encoding scriptSig / witness item lengths >= 253

        # empty scriptSig, no witness
        self.assertEqual(calculate_input_weight(""),
                         (SKELETON_BYTES + SMALL_LEN_BYTES) * WITNESS_SCALE_FACTOR)
        self.assertEqual(calculate_input_weight("", None),
                         (SKELETON_BYTES + SMALL_LEN_BYTES) * WITNESS_SCALE_FACTOR)
        # small scriptSig, no witness
        scriptSig_small = "00"*252
        self.assertEqual(calculate_input_weight(scriptSig_small, None),
                         (SKELETON_BYTES + SMALL_LEN_BYTES + 252) * WITNESS_SCALE_FACTOR)
        # small scriptSig, empty witness stack
        self.assertEqual(calculate_input_weight(scriptSig_small, []),
                         (SKELETON_BYTES + SMALL_LEN_BYTES + 252) * WITNESS_SCALE_FACTOR + SMALL_LEN_BYTES)
        # large scriptSig, no witness
        scriptSig_large = "00"*253
        self.assertEqual(calculate_input_weight(scriptSig_large, None),
                         (SKELETON_BYTES + LARGE_LEN_BYTES + 253) * WITNESS_SCALE_FACTOR)
        # large scriptSig, empty witness stack
        self.assertEqual(calculate_input_weight(scriptSig_large, []),
                         (SKELETON_BYTES + LARGE_LEN_BYTES + 253) * WITNESS_SCALE_FACTOR + SMALL_LEN_BYTES)
        # empty scriptSig, 5 small witness stack items
        self.assertEqual(calculate_input_weight("", ["00", "11", "22", "33", "44"]),
                         ((SKELETON_BYTES + SMALL_LEN_BYTES) * WITNESS_SCALE_FACTOR) + SMALL_LEN_BYTES + 5 * SMALL_LEN_BYTES + 5)
        # empty scriptSig, 253 small witness stack items
        self.assertEqual(calculate_input_weight("", ["00"]*253),
                         ((SKELETON_BYTES + SMALL_LEN_BYTES) * WITNESS_SCALE_FACTOR) + LARGE_LEN_BYTES + 253 * SMALL_LEN_BYTES + 253)
        # small scriptSig, 3 large witness stack items
        self.assertEqual(calculate_input_weight(scriptSig_small, ["00"*253]*3),
                         ((SKELETON_BYTES + SMALL_LEN_BYTES + 252) * WITNESS_SCALE_FACTOR) + SMALL_LEN_BYTES + 3 * LARGE_LEN_BYTES + 3*253)
        # large scriptSig, 3 large witness stack items
        self.assertEqual(calculate_input_weight(scriptSig_large, ["00"*253]*3),
                         ((SKELETON_BYTES + LARGE_LEN_BYTES + 253) * WITNESS_SCALE_FACTOR) + SMALL_LEN_BYTES + 3 * LARGE_LEN_BYTES + 3*253)