diff options
-rw-r--r-- | src/wallet/rpc/backup.cpp | 6 | ||||
-rw-r--r-- | src/wallet/rpc/encrypt.cpp | 26 | ||||
-rw-r--r-- | src/wallet/rpc/transactions.cpp | 5 | ||||
-rw-r--r-- | src/wallet/wallet.cpp | 6 | ||||
-rw-r--r-- | src/wallet/wallet.h | 9 | ||||
-rwxr-xr-x | test/functional/wallet_importdescriptors.py | 28 | ||||
-rwxr-xr-x | test/functional/wallet_transactiontime_rescan.py | 39 |
7 files changed, 106 insertions, 13 deletions
diff --git a/src/wallet/rpc/backup.cpp b/src/wallet/rpc/backup.cpp index e3701fe312..744537cfbd 100644 --- a/src/wallet/rpc/backup.cpp +++ b/src/wallet/rpc/backup.cpp @@ -1650,10 +1650,14 @@ RPCHelpMan importdescriptors() } WalletRescanReserver reserver(*pwallet); - if (!reserver.reserve()) { + if (!reserver.reserve(/*with_passphrase=*/true)) { throw JSONRPCError(RPC_WALLET_ERROR, "Wallet is currently rescanning. Abort existing rescan or wait."); } + // Ensure that the wallet is not locked for the remainder of this RPC, as + // the passphrase is used to top up the keypool. + LOCK(pwallet->m_relock_mutex); + const UniValue& requests = main_request.params[0]; const int64_t minimum_timestamp = 1; int64_t now = 0; diff --git a/src/wallet/rpc/encrypt.cpp b/src/wallet/rpc/encrypt.cpp index fcf25e01d6..38960bda7b 100644 --- a/src/wallet/rpc/encrypt.cpp +++ b/src/wallet/rpc/encrypt.cpp @@ -90,7 +90,7 @@ RPCHelpMan walletpassphrase() std::weak_ptr<CWallet> weak_wallet = wallet; pwallet->chain().rpcRunLater(strprintf("lockwallet(%s)", pwallet->GetName()), [weak_wallet, relock_time] { if (auto shared_wallet = weak_wallet.lock()) { - LOCK(shared_wallet->cs_wallet); + LOCK2(shared_wallet->m_relock_mutex, shared_wallet->cs_wallet); // Skip if this is not the most recent rpcRunLater callback. if (shared_wallet->nRelockTime != relock_time) return; shared_wallet->Lock(); @@ -122,12 +122,16 @@ RPCHelpMan walletpassphrasechange() std::shared_ptr<CWallet> const pwallet = GetWalletForJSONRPCRequest(request); if (!pwallet) return UniValue::VNULL; - LOCK(pwallet->cs_wallet); - if (!pwallet->IsCrypted()) { throw JSONRPCError(RPC_WALLET_WRONG_ENC_STATE, "Error: running with an unencrypted wallet, but walletpassphrasechange was called."); } + if (pwallet->IsScanningWithPassphrase()) { + throw JSONRPCError(RPC_WALLET_ERROR, "Error: the wallet is currently being used to rescan the blockchain for related transactions. Please call `abortrescan` before changing the passphrase."); + } + + LOCK2(pwallet->m_relock_mutex, pwallet->cs_wallet); + // TODO: get rid of these .c_str() calls by implementing SecureString::operator=(std::string) // Alternately, find a way to make request.params[0] mlock()'d to begin with. SecureString strOldWalletPass; @@ -175,12 +179,16 @@ RPCHelpMan walletlock() std::shared_ptr<CWallet> const pwallet = GetWalletForJSONRPCRequest(request); if (!pwallet) return UniValue::VNULL; - LOCK(pwallet->cs_wallet); - if (!pwallet->IsCrypted()) { throw JSONRPCError(RPC_WALLET_WRONG_ENC_STATE, "Error: running with an unencrypted wallet, but walletlock was called."); } + if (pwallet->IsScanningWithPassphrase()) { + throw JSONRPCError(RPC_WALLET_ERROR, "Error: the wallet is currently being used to rescan the blockchain for related transactions. Please call `abortrescan` before locking the wallet."); + } + + LOCK2(pwallet->m_relock_mutex, pwallet->cs_wallet); + pwallet->Lock(); pwallet->nRelockTime = 0; @@ -219,8 +227,6 @@ RPCHelpMan encryptwallet() std::shared_ptr<CWallet> const pwallet = GetWalletForJSONRPCRequest(request); if (!pwallet) return UniValue::VNULL; - LOCK(pwallet->cs_wallet); - if (pwallet->IsWalletFlagSet(WALLET_FLAG_DISABLE_PRIVATE_KEYS)) { throw JSONRPCError(RPC_WALLET_ENCRYPTION_FAILED, "Error: wallet does not contain private keys, nothing to encrypt."); } @@ -229,6 +235,12 @@ RPCHelpMan encryptwallet() throw JSONRPCError(RPC_WALLET_WRONG_ENC_STATE, "Error: running with an encrypted wallet, but encryptwallet was called."); } + if (pwallet->IsScanningWithPassphrase()) { + throw JSONRPCError(RPC_WALLET_ERROR, "Error: the wallet is currently being used to rescan the blockchain for related transactions. Please call `abortrescan` before encrypting the wallet."); + } + + LOCK2(pwallet->m_relock_mutex, pwallet->cs_wallet); + // TODO: get rid of this .c_str() by implementing SecureString::operator=(std::string) // Alternately, find a way to make request.params[0] mlock()'d to begin with. SecureString strWalletPass; diff --git a/src/wallet/rpc/transactions.cpp b/src/wallet/rpc/transactions.cpp index e590aa1f08..3bfe296d90 100644 --- a/src/wallet/rpc/transactions.cpp +++ b/src/wallet/rpc/transactions.cpp @@ -872,15 +872,18 @@ RPCHelpMan rescanblockchain() wallet.BlockUntilSyncedToCurrentChain(); WalletRescanReserver reserver(*pwallet); - if (!reserver.reserve()) { + if (!reserver.reserve(/*with_passphrase=*/true)) { throw JSONRPCError(RPC_WALLET_ERROR, "Wallet is currently rescanning. Abort existing rescan or wait."); } int start_height = 0; std::optional<int> stop_height; uint256 start_block; + + LOCK(pwallet->m_relock_mutex); { LOCK(pwallet->cs_wallet); + EnsureWalletIsUnlocked(*pwallet); int tip_height = pwallet->GetLastBlockHeight(); if (!request.params[0].isNull()) { diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp index eed40f9462..f2c0fdcb3d 100644 --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -552,7 +552,7 @@ bool CWallet::ChangeWalletPassphrase(const SecureString& strOldWalletPassphrase, bool fWasLocked = IsLocked(); { - LOCK(cs_wallet); + LOCK2(m_relock_mutex, cs_wallet); Lock(); CCrypter crypter; @@ -787,7 +787,7 @@ bool CWallet::EncryptWallet(const SecureString& strWalletPassphrase) return false; { - LOCK(cs_wallet); + LOCK2(m_relock_mutex, cs_wallet); mapMasterKeys[++nMasterKeyMaxID] = kMasterKey; WalletBatch* encrypted_batch = new WalletBatch(GetDatabase()); if (!encrypted_batch->TxnBegin()) { @@ -3412,7 +3412,7 @@ bool CWallet::Lock() return false; { - LOCK(cs_wallet); + LOCK2(m_relock_mutex, cs_wallet); if (!vMasterKey.empty()) { memory_cleanse(vMasterKey.data(), vMasterKey.size() * sizeof(decltype(vMasterKey)::value_type)); vMasterKey.clear(); diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h index 6ec95220d6..5dd694e114 100644 --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -243,6 +243,7 @@ private: std::atomic<bool> fAbortRescan{false}; std::atomic<bool> fScanningWallet{false}; // controlled by WalletRescanReserver std::atomic<bool> m_attaching_chain{false}; + std::atomic<bool> m_scanning_with_passphrase{false}; std::atomic<int64_t> m_scanning_start{0}; std::atomic<double> m_scanning_progress{0}; friend class WalletRescanReserver; @@ -463,6 +464,7 @@ public: void AbortRescan() { fAbortRescan = true; } bool IsAbortingRescan() const { return fAbortRescan; } bool IsScanning() const { return fScanningWallet; } + bool IsScanningWithPassphrase() const { return m_scanning_with_passphrase; } int64_t ScanningDuration() const { return fScanningWallet ? GetTimeMillis() - m_scanning_start : 0; } double ScanningProgress() const { return fScanningWallet ? (double) m_scanning_progress : 0; } @@ -482,6 +484,9 @@ public: // Used to prevent concurrent calls to walletpassphrase RPC. Mutex m_unlock_mutex; + // Used to prevent deleting the passphrase from memory when it is still in use. + RecursiveMutex m_relock_mutex; + bool Unlock(const SecureString& strWalletPassphrase, bool accept_no_keys = false); bool ChangeWalletPassphrase(const SecureString& strOldWalletPassphrase, const SecureString& strNewWalletPassphrase); bool EncryptWallet(const SecureString& strWalletPassphrase); @@ -962,12 +967,13 @@ private: public: explicit WalletRescanReserver(CWallet& w) : m_wallet(w) {} - bool reserve() + bool reserve(bool with_passphrase = false) { assert(!m_could_reserve); if (m_wallet.fScanningWallet.exchange(true)) { return false; } + m_wallet.m_scanning_with_passphrase.exchange(with_passphrase); m_wallet.m_scanning_start = GetTimeMillis(); m_wallet.m_scanning_progress = 0; m_could_reserve = true; @@ -987,6 +993,7 @@ public: { if (m_could_reserve) { m_wallet.fScanningWallet = false; + m_wallet.m_scanning_with_passphrase = false; } } }; diff --git a/test/functional/wallet_importdescriptors.py b/test/functional/wallet_importdescriptors.py index 9e830088ac..e66eb2c289 100755 --- a/test/functional/wallet_importdescriptors.py +++ b/test/functional/wallet_importdescriptors.py @@ -667,5 +667,33 @@ class ImportDescriptorsTest(BitcoinTestFramework): success=True, warnings=["Unknown output type, cannot set descriptor to active."]) + self.log.info("Test importing a descriptor to an encrypted wallet") + + descriptor = {"desc": descsum_create("pkh(" + xpriv + "/1h/*h)"), + "timestamp": "now", + "active": True, + "range": [0,4000], + "next_index": 4000} + + self.nodes[0].createwallet("temp_wallet", blank=True, descriptors=True) + temp_wallet = self.nodes[0].get_wallet_rpc("temp_wallet") + temp_wallet.importdescriptors([descriptor]) + self.generatetoaddress(self.nodes[0], COINBASE_MATURITY + 1, temp_wallet.getnewaddress()) + self.generatetoaddress(self.nodes[0], COINBASE_MATURITY + 1, temp_wallet.getnewaddress()) + + self.nodes[0].createwallet("encrypted_wallet", blank=True, descriptors=True, passphrase="passphrase") + encrypted_wallet = self.nodes[0].get_wallet_rpc("encrypted_wallet") + + descriptor["timestamp"] = 0 + descriptor["next_index"] = 0 + + batch = [] + batch.append(encrypted_wallet.walletpassphrase.get_request("passphrase", 3)) + batch.append(encrypted_wallet.importdescriptors.get_request([descriptor])) + + encrypted_wallet.batch(batch) + + assert_equal(temp_wallet.getbalance(), encrypted_wallet.getbalance()) + if __name__ == '__main__': ImportDescriptorsTest().main() diff --git a/test/functional/wallet_transactiontime_rescan.py b/test/functional/wallet_transactiontime_rescan.py index de9616b4a1..904013cdef 100755 --- a/test/functional/wallet_transactiontime_rescan.py +++ b/test/functional/wallet_transactiontime_rescan.py @@ -14,6 +14,9 @@ from test_framework.util import ( assert_raises_rpc_error, set_node_times, ) +from test_framework.wallet_util import ( + get_generate_key, +) class TransactionTimeRescanTest(BitcoinTestFramework): @@ -23,6 +26,10 @@ class TransactionTimeRescanTest(BitcoinTestFramework): def set_test_params(self): self.setup_clean_chain = False self.num_nodes = 3 + self.extra_args = [["-keypool=400"], + ["-keypool=400"], + [] + ] def skip_test_if_missing_module(self): self.skip_if_no_wallet() @@ -167,6 +174,38 @@ class TransactionTimeRescanTest(BitcoinTestFramework): assert_raises_rpc_error(-8, "Invalid stop_height", restorewo_wallet.rescanblockchain, 1, -1) assert_raises_rpc_error(-8, "stop_height must be greater than start_height", restorewo_wallet.rescanblockchain, 20, 10) + self.log.info("Test `rescanblockchain` fails when wallet is encrypted and locked") + usernode.createwallet(wallet_name="enc_wallet", passphrase="passphrase") + enc_wallet = usernode.get_wallet_rpc("enc_wallet") + assert_raises_rpc_error(-13, "Error: Please enter the wallet passphrase with walletpassphrase first.", enc_wallet.rescanblockchain) + + if not self.options.descriptors: + self.log.info("Test rescanning an encrypted wallet") + hd_seed = get_generate_key().privkey + + usernode.createwallet(wallet_name="temp_wallet", blank=True, descriptors=False) + temp_wallet = usernode.get_wallet_rpc("temp_wallet") + temp_wallet.sethdseed(seed=hd_seed) + + for i in range(399): + temp_wallet.getnewaddress() + + self.generatetoaddress(usernode, COINBASE_MATURITY + 1, temp_wallet.getnewaddress()) + self.generatetoaddress(usernode, COINBASE_MATURITY + 1, temp_wallet.getnewaddress()) + + minernode.createwallet("encrypted_wallet", blank=True, passphrase="passphrase", descriptors=False) + encrypted_wallet = minernode.get_wallet_rpc("encrypted_wallet") + + encrypted_wallet.walletpassphrase("passphrase", 1) + encrypted_wallet.sethdseed(seed=hd_seed) + + batch = [] + batch.append(encrypted_wallet.walletpassphrase.get_request("passphrase", 3)) + batch.append(encrypted_wallet.rescanblockchain.get_request()) + + encrypted_wallet.batch(batch) + + assert_equal(encrypted_wallet.getbalance(), temp_wallet.getbalance()) if __name__ == '__main__': TransactionTimeRescanTest().main() |