diff --git a/src/jmclient/wallet.py b/src/jmclient/wallet.py index 094a326d7..b417f60d9 100644 --- a/src/jmclient/wallet.py +++ b/src/jmclient/wallet.py @@ -19,6 +19,7 @@ from decimal import Decimal from numbers import Integral from math import exp +from typing import Any, Dict, Optional, Tuple from .configure import jm_single @@ -280,32 +281,28 @@ def select_utxos(self, mixdepth, amount, utxo_filter=(), select_fn=None, 'value': utxos[s['utxo']][1]} for s in selected} - def get_balance_by_mixdepth(self, max_mixdepth=float('Inf'), - include_disabled=True, maxheight=None): - """ By default this returns a dict of aggregated bitcoin - balance per mixdepth: {0: N sats, 1: M sats, ...} for all - currently available mixdepths. - If max_mixdepth is set it will return balances only up - to that mixdepth. + def get_balance_at_mixdepth(self, mixdepth: int, + include_disabled: bool = True, + maxheight: Optional[int] = None) -> int: + """ By default this returns aggregated bitcoin balance at mixdepth. To get only enabled balance, set include_disabled=False. To get balances only with a certain number of confs, use maxheight. """ - balance_dict = collections.defaultdict(int) - for mixdepth, utxomap in self._utxo.items(): - if mixdepth > max_mixdepth: - continue - if not include_disabled: - utxomap = {k: v for k, v in utxomap.items( - ) if not self.is_disabled(*k)} - if maxheight is not None: - utxomap = {k: v for k, v in utxomap.items( - ) if v[2] <= maxheight} - value = sum(x[1] for x in utxomap.values()) - balance_dict[mixdepth] = value - return balance_dict - - def get_utxos_by_mixdepth(self): - return deepcopy(self._utxo) + utxomap = self._utxo.get(mixdepth) + if not utxomap: + return 0 + if not include_disabled: + utxomap = {k: v for k, v in utxomap.items( + ) if not self.is_disabled(*k)} + if maxheight is not None: + utxomap = {k: v for k, v in utxomap.items( + ) if v[2] <= maxheight} + return sum(x[1] for x in utxomap.values()) + + def get_utxos_at_mixdepth(self, mixdepth: int) -> \ + Dict[Tuple[bytes, int], Tuple[Tuple, int, int]]: + utxomap = self._utxo.get(mixdepth) + return deepcopy(utxomap) if utxomap else {} def __eq__(self, o): return self._utxo == o._utxo and \ @@ -377,6 +374,7 @@ def __init__(self, storage, gap_limit=6, merge_algorithm_name=None, self._storage = storage self._utxos = None self._addr_labels = None + self._cache = None # highest mixdepth ever used in wallet, important for synching self.max_mixdepth = None # effective maximum mixdepth to be used by joinmarket @@ -385,10 +383,13 @@ def __init__(self, storage, gap_limit=6, merge_algorithm_name=None, # {script: path}, should always hold mappings for all "known" keys self._script_map = {} + # {address: path}, should always hold mappings for all "known" keys + self._addr_map = {} self._load_storage() assert self._utxos is not None + assert self._cache is not None assert self.max_mixdepth is not None assert self.max_mixdepth >= 0 assert self.network in ('mainnet', 'testnet', 'signet') @@ -425,6 +426,7 @@ def _load_storage(self): self.network = self._storage.data[b'network'].decode('ascii') self._utxos = UTXOManager(self._storage, self.merge_algorithm) self._addr_labels = AddressLabelsManager(self._storage) + self._cache = self._storage.data.setdefault(b'cache', {}) def get_storage_location(self): """ Return the location of the @@ -538,34 +540,24 @@ def get_key_from_addr(self, addr): """ There should be no reason for code outside the wallet to need a privkey. """ - script = self._ENGINE.address_to_script(addr) - path = self.script_to_path(script) + path = self.addr_to_path(addr) privkey = self._get_key_from_path(path)[0] return privkey - def _get_addr_int_ext(self, address_type, mixdepth): - if address_type == self.ADDRESS_TYPE_EXTERNAL: - script = self.get_external_script(mixdepth) - elif address_type == self.ADDRESS_TYPE_INTERNAL: - script = self.get_internal_script(mixdepth) - else: - assert 0 - return self.script_to_addr(script) - def get_external_addr(self, mixdepth): """ Return an address suitable for external distribution, including funding the wallet from other sources, or receiving payments or donations. JoinMarket will never generate these addresses for internal use. """ - return self._get_addr_int_ext(self.ADDRESS_TYPE_EXTERNAL, mixdepth) + return self.get_new_addr(mixdepth, self.ADDRESS_TYPE_EXTERNAL) def get_internal_addr(self, mixdepth): """ Return an address for internal usage, as change addresses and when participating in transactions initiated by other parties. """ - return self._get_addr_int_ext(self.ADDRESS_TYPE_INTERNAL, mixdepth) + return self.get_new_addr(mixdepth, self.ADDRESS_TYPE_INTERNAL) def get_external_script(self, mixdepth): return self.get_new_script(mixdepth, self.ADDRESS_TYPE_EXTERNAL) @@ -575,21 +567,33 @@ def get_internal_script(self, mixdepth): @classmethod def addr_to_script(cls, addr): + """ + Try not to call this slow method. Instead, call addr_to_path, + followed by get_script_from_path, as those are cached. + """ return cls._ENGINE.address_to_script(addr) @classmethod def pubkey_to_script(cls, pubkey): + """ + Try not to call this slow method. Instead, call + get_script_from_path if possible, as that is cached. + """ return cls._ENGINE.pubkey_to_script(pubkey) @classmethod def pubkey_to_addr(cls, pubkey): + """ + Try not to call this slow method. Instead, call + get_address_from_path if possible, as that is cached. + """ return cls._ENGINE.pubkey_to_address(pubkey) - def script_to_addr(self, script): - assert self.is_known_script(script) + def script_to_addr(self, script, + validate_cache: bool = False): path = self.script_to_path(script) - engine = self._get_key_from_path(path)[1] - return engine.script_to_address(script) + return self.get_address_from_path(path, + validate_cache=validate_cache) def get_script_code(self, script): """ @@ -600,8 +604,7 @@ def get_script_code(self, script): For non-segwit wallets, raises EngineError. """ path = self.script_to_path(script) - priv, engine = self._get_key_from_path(path) - pub = engine.privkey_to_pubkey(priv) + pub, engine = self._get_pubkey_from_path(path) return engine.pubkey_to_script_code(pub) @classmethod @@ -616,22 +619,42 @@ def pubkey_has_script(cls, pubkey, script): def get_key(self, mixdepth, address_type, index): raise NotImplementedError() - def get_addr(self, mixdepth, address_type, index): - script = self.get_script(mixdepth, address_type, index) - return self.script_to_addr(script) - - def get_address_from_path(self, path): - script = self.get_script_from_path(path) - return self.script_to_addr(script) - - def get_new_addr(self, mixdepth, address_type): + def get_addr(self, mixdepth, address_type, index, + validate_cache: bool = False): + path = self.get_path(mixdepth, address_type, index) + return self.get_address_from_path(path, + validate_cache=validate_cache) + + def get_address_from_path(self, path, + validate_cache: bool = False): + cache = self._get_cache_for_path(path) + addr = cache.get(b'A') + if addr is not None: + addr = addr.decode('ascii') + if addr is None or validate_cache: + engine = self._get_pubkey_from_path(path)[1] + script = self.get_script_from_path(path, + validate_cache=validate_cache) + new_addr = engine.script_to_address(script) + if addr is None: + addr = new_addr + cache[b'A'] = addr.encode('ascii') + elif addr != new_addr: + raise WalletError("Wallet cache validation failed") + return addr + + def get_new_addr(self, mixdepth, address_type, + validate_cache: bool = True): """ use get_external_addr/get_internal_addr """ - script = self.get_new_script(mixdepth, address_type) - return self.script_to_addr(script) + script = self.get_new_script(mixdepth, address_type, + validate_cache=validate_cache) + return self.script_to_addr(script, + validate_cache=validate_cache) - def get_new_script(self, mixdepth, address_type): + def get_new_script(self, mixdepth, address_type, + validate_cache: bool = True): raise NotImplementedError() def get_wif(self, mixdepth, address_type, index): @@ -845,10 +868,19 @@ def get_balance_by_mixdepth(self, verbose=True, confirmations, set maxheight to max acceptable blockheight. returns: {mixdepth: value} """ + balances = collections.defaultdict(int) + for md in range(self.mixdepth + 1): + balances[md] = self.get_balance_at_mixdepth(md, verbose=verbose, + include_disabled=include_disabled, maxheight=maxheight) + return balances + + def get_balance_at_mixdepth(self, mixdepth, + verbose: bool = True, + include_disabled: bool = False, + maxheight: Optional[int] = None) -> int: # TODO: verbose - return self._utxos.get_balance_by_mixdepth(max_mixdepth=self.mixdepth, - include_disabled=include_disabled, - maxheight=maxheight) + return self._utxos.get_balance_at_mixdepth(mixdepth, + include_disabled=include_disabled, maxheight=maxheight) def get_utxos_by_mixdepth(self, include_disabled=False, includeheight=False): """ @@ -859,25 +891,35 @@ def get_utxos_by_mixdepth(self, include_disabled=False, includeheight=False): {'script': bytes, 'path': tuple, 'value': int}}} (if `includeheight` is True, adds key 'height': int) """ - mix_utxos = self._utxos.get_utxos_by_mixdepth() - script_utxos = collections.defaultdict(dict) - for md, data in mix_utxos.items(): - if md > self.mixdepth: - continue + for md in range(self.mixdepth + 1): + script_utxos[md] = self.get_utxos_at_mixdepth(md, + include_disabled=include_disabled, includeheight=includeheight) + return script_utxos + + def get_utxos_at_mixdepth(self, mixdepth: int, + include_disabled: bool = False, + includeheight: bool = False) -> \ + Dict[Tuple[bytes, int], Dict[str, Any]]: + script_utxos = {} + if 0 <= mixdepth <= self.mixdepth: + data = self._utxos.get_utxos_at_mixdepth(mixdepth) for utxo, (path, value, height) in data.items(): if not include_disabled and self._utxos.is_disabled(*utxo): continue script = self.get_script_from_path(path) addr = self.get_address_from_path(path) label = self.get_address_label(addr) - script_utxos[md][utxo] = {'script': script, - 'path': path, - 'value': value, - 'address': addr, - 'label': label} + script_utxo = { + 'script': script, + 'path': path, + 'value': value, + 'address': addr, + 'label': label, + } if includeheight: - script_utxos[md][utxo]['height'] = height + script_utxo['height'] = height + script_utxos[utxo] = script_utxo return script_utxos @@ -910,7 +952,8 @@ def _get_merge_algorithm(cls, algorithm_name=None): def _get_mixdepth_from_path(self, path): raise NotImplementedError() - def get_script_from_path(self, path): + def get_script_from_path(self, path, + validate_cache: bool = False): """ internal note: This is the final sink for all operations that somehow need to derive a script. If anything goes wrong when deriving a @@ -921,15 +964,72 @@ def get_script_from_path(self, path): returns: script """ - raise NotImplementedError() + cache = self._get_cache_for_path(path) + script = cache.get(b'S') + if script is None or validate_cache: + pubkey, engine = self._get_pubkey_from_path(path, + validate_cache=validate_cache) + new_script = engine.pubkey_to_script(pubkey) + if script is None: + cache[b'S'] = script = new_script + elif script != new_script: + raise WalletError("Wallet cache validation failed") + return script - def get_script(self, mixdepth, address_type, index): + def get_script(self, mixdepth, address_type, index, + validate_cache: bool = False): path = self.get_path(mixdepth, address_type, index) - return self.get_script_from_path(path) + return self.get_script_from_path(path, validate_cache=validate_cache) - def _get_key_from_path(self, path): + def _get_key_from_path(self, path, + validate_cache: bool = False): raise NotImplementedError() + def _get_keypair_from_path(self, path, + validate_cache: bool = False): + privkey, engine = self._get_key_from_path(path, + validate_cache=validate_cache) + cache = self._get_cache_for_path(path) + pubkey = cache.get(b'P') + if pubkey is None or validate_cache: + new_pubkey = engine.privkey_to_pubkey(privkey) + if pubkey is None: + cache[b'P'] = pubkey = new_pubkey + elif pubkey != new_pubkey: + raise WalletError("Wallet cache validation failed") + return privkey, pubkey, engine + + def _get_pubkey_from_path(self, path, + validate_cache: bool = False): + privkey, pubkey, engine = self._get_keypair_from_path(path, + validate_cache=validate_cache) + return pubkey, engine + + def _get_cache_keys_for_path(self, path): + return path[:1] + tuple(map(_int_to_bytestr, path[1:])) + + def _get_cache_for_path(self, path): + assert len(path) > 0 + cache = self._cache + for k in self._get_cache_keys_for_path(path): + cache = cache.setdefault(k, {}) + return cache + + def _delete_cache_for_path(self, path) -> bool: + assert len(path) > 0 + def recurse(cache, itr): + k = next(itr, None) + if k is None: + cache.clear() + else: + child = cache.get(k) + if not child or not recurse(child, itr): + return False + if not child: + del cache[k] + return True + return recurse(self._cache, iter(self._get_cache_keys_for_path(path))) + def get_path_repr(self, path): """ Get a human-readable representation of the wallet path. @@ -984,7 +1084,7 @@ def sign_message(self, message, path): signature as base64-encoded string """ priv, engine = self._get_key_from_path(path) - addr = engine.privkey_to_address(priv) + addr = self.get_address_from_path(path) return addr, engine.sign_message(priv, message) def get_wallet_name(self): @@ -1038,8 +1138,8 @@ def is_known_addr(self, addr): returns: bool """ - script = self.addr_to_script(addr) - return script in self._script_map + assert isinstance(addr, str) + return addr in self._addr_map def is_known_script(self, script): """ @@ -1054,8 +1154,8 @@ def is_known_script(self, script): return script in self._script_map def get_addr_mixdepth(self, addr): - script = self.addr_to_script(addr) - return self.get_script_mixdepth(script) + path = self.addr_to_path(addr) + return self._get_mixdepth_from_path(path) def get_script_mixdepth(self, script): path = self.script_to_path(script) @@ -1068,16 +1168,26 @@ def yield_known_paths(self): returns: path generator """ - for s in self._script_map.values(): - yield s + for md in range(self.max_mixdepth + 1): + for path in self.yield_imported_paths(md): + yield path + + def _populate_maps(self, paths): + for path in paths: + self._script_map[self.get_script_from_path(path)] = path + self._addr_map[self.get_address_from_path(path)] = path def addr_to_path(self, addr): - script = self.addr_to_script(addr) - return self.script_to_path(script) + assert isinstance(addr, str) + path = self._addr_map.get(addr) + assert path is not None + return path def script_to_path(self, script): - assert script in self._script_map - return self._script_map[script] + assert isinstance(script, bytes) + path = self._script_map.get(script) + assert path is not None + return path def set_next_index(self, mixdepth, address_type, index, force=False): """ @@ -1379,9 +1489,8 @@ def create_psbt_from_tx(self, tx, spent_outs=None, force_witness_utxo=True): # this happens when an input is provided but it's not in # this wallet; in this case, we cannot set the redeem script. continue - privkey, _ = self._get_key_from_path(path) - txinput.redeem_script = btc.pubkey_to_p2wpkh_script( - btc.privkey_to_pubkey(privkey)) + pubkey = self._get_pubkey_from_path(path)[0] + txinput.redeem_script = btc.pubkey_to_p2wpkh_script(pubkey) return new_psbt def sign_psbt(self, in_psbt, with_sign_result=False): @@ -1451,9 +1560,8 @@ def sign_psbt(self, in_psbt, with_sign_result=False): # this happens when an input is provided but it's not in # this wallet; in this case, we cannot set the redeem script. continue - privkey, _ = self._get_key_from_path(path) - txinput.redeem_script = btc.pubkey_to_p2wpkh_script( - btc.privkey_to_pubkey(privkey)) + pubkey = self._get_pubkey_from_path(path)[0] + txinput.redeem_script = btc.pubkey_to_p2wpkh_script(pubkey) # no else branch; any other form of scriptPubKey will just be # ignored. try: @@ -1767,12 +1875,7 @@ def _load_storage(self): for md, keys in self._storage.data[self._IMPORTED_STORAGE_KEY].items(): md = int(md) self._imported[md] = keys - for index, (key, key_type) in enumerate(keys): - if not key: - # imported key was removed - continue - assert key_type in self._ENGINES - self._cache_imported_key(md, key, key_type, index) + self._populate_maps(self.yield_imported_paths(md)) def save(self): import_data = {} @@ -1841,8 +1944,8 @@ def remove_imported_key(self, script=None, address=None, path=None): raise Exception("Only one of script|address|path may be given.") if address: - script = self.addr_to_script(address) - if script: + path = self.addr_to_path(address) + elif script: path = self.script_to_path(script) if not path: @@ -1855,18 +1958,19 @@ def remove_imported_key(self, script=None, address=None, path=None): if not script: script = self.get_script_from_path(path) + if not address: + address = self.get_address_from_path(path) # we need to retain indices self._imported[path[1]][path[2]] = (b'', -1) del self._script_map[script] + del self._addr_map[address] + self._delete_cache_for_path(path) def _cache_imported_key(self, mixdepth, privkey, key_type, index): - engine = self._ENGINES[key_type] path = (self._IMPORTED_ROOT_PATH, mixdepth, index) - - self._script_map[engine.key_to_script(privkey)] = path - + self._populate_maps((path,)) return path def _get_mixdepth_from_path(self, path): @@ -1876,9 +1980,11 @@ def _get_mixdepth_from_path(self, path): assert len(path) == 3 return path[1] - def _get_key_from_path(self, path): + def _get_key_from_path(self, path, + validate_cache: bool = False): if not self._is_imported_path(path): - return super()._get_key_from_path(path) + return super()._get_key_from_path(path, + validate_cache=validate_cache) assert len(path) == 3 md, i = path[1], path[2] @@ -1901,7 +2007,7 @@ def _is_imported_path(cls, path): def is_standard_wallet_script(self, path): if self._is_imported_path(path): - engine = self._get_key_from_path(path)[1] + engine = self._get_pubkey_from_path(path)[1] return engine == self._ENGINE return super().is_standard_wallet_script(path) @@ -1932,13 +2038,6 @@ def get_details(self, path): return super().get_details(path) return path[1], 'imported', path[2] - def get_script_from_path(self, path): - if not self._is_imported_path(path): - return super().get_script_from_path(path) - - priv, engine = self._get_key_from_path(path) - return engine.key_to_script(priv) - class BIP39WalletMixin(object): """ @@ -2009,6 +2108,7 @@ class BIP32Wallet(BaseWallet): def __init__(self, storage, **kwargs): self._entropy = None + self._key_ident = None # {mixdepth: {type: index}} with type being 0/1 corresponding # to external/internal addresses self._index_cache = None @@ -2027,7 +2127,7 @@ def __init__(self, storage, **kwargs): # used to verify paths for sanity checking and for wallet id creation self._key_ident = b'' # otherwise get_bip32_* won't work self._key_ident = self._get_key_ident() - self._populate_script_map() + self._populate_maps(self.yield_known_bip32_paths()) self.disable_new_scripts = False @classmethod @@ -2073,13 +2173,14 @@ def _get_key_ident(self): self.get_bip32_priv_export(0, self.BIP32_EXT_ID).encode('ascii')).digest())\ .digest()[:3] - def _populate_script_map(self): + def yield_known_paths(self): + return chain(super().yield_known_paths(), self.yield_known_bip32_paths()) + + def yield_known_bip32_paths(self): for md in self._index_cache: for address_type in (self.BIP32_EXT_ID, self.BIP32_INT_ID): for i in range(self._index_cache[md][address_type]): - path = self.get_path(md, address_type, i) - script = self.get_script_from_path(path) - self._script_map[script] = path + yield self.get_path(md, address_type, i) def save(self): for md, data in self._index_cache.items(): @@ -2114,10 +2215,7 @@ def _derive_bip32_master_key(cls, seed): def _get_supported_address_types(cls): return (cls.BIP32_EXT_ID, cls.BIP32_INT_ID) - def get_script_from_path(self, path): - if not self._is_my_bip32_path(path): - raise WalletError("unable to get script for unknown key path") - + def _check_path(self, path): md, address_type, index = self.get_details(path) if not 0 <= md <= self.max_mixdepth: @@ -2130,12 +2228,22 @@ def get_script_from_path(self, path): and address_type != FidelityBondMixin.BIP32_TIMELOCK_ID: #special case for timelocked addresses because for them the #concept of a "next address" cant be used - return self.get_new_script_override_disable(md, address_type) - - priv, engine = self._get_key_from_path(path) - script = engine.key_to_script(priv) - - return script + self._set_index_cache(md, address_type, current_index + 1) + self._populate_maps((path,)) + + def get_script_from_path(self, path, + validate_cache: bool = False): + if self._is_my_bip32_path(path): + self._check_path(path) + return super().get_script_from_path(path, + validate_cache=validate_cache) + + def get_address_from_path(self, path, + validate_cache: bool = False): + if self._is_my_bip32_path(path): + self._check_path(path) + return super().get_address_from_path(path, + validate_cache=validate_cache) def get_path(self, mixdepth=None, address_type=None, index=None): if mixdepth is not None: @@ -2151,7 +2259,6 @@ def get_path(self, mixdepth=None, address_type=None, index=None): assert isinstance(index, Integral) if address_type is None: raise Exception("address_type must be set if index is set") - assert index <= self._index_cache[mixdepth][address_type] assert index < self.BIP32_MAX_PATH_LEVEL return tuple(chain(self._get_bip32_export_path(mixdepth, address_type), (index,))) @@ -2200,30 +2307,62 @@ def _get_mixdepth_from_path(self, path): return path[len(self._get_bip32_base_path())] - def _get_key_from_path(self, path): + def _get_key_from_path(self, path, + validate_cache: bool = False): if not self._is_my_bip32_path(path): raise WalletError("Invalid path, unknown root: {}".format(path)) - - return self._ENGINE.derive_bip32_privkey(self._master_key, path), \ - self._ENGINE + cache = self._get_cache_for_path(path) + privkey = cache.get(b'p') + if privkey is None or validate_cache: + new_privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path) + if privkey is None: + cache[b'p'] = privkey = new_privkey + elif privkey != new_privkey: + raise WalletError("Wallet cache validation failed") + return privkey, self._ENGINE + + def _get_keypair_from_path(self, path, + validate_cache: bool = False): + if not self._is_my_bip32_path(path): + return super()._get_keypair_from_path(path, + validate_cache=validate_cache) + cache = self._get_cache_for_path(path) + privkey = cache.get(b'p') + if privkey is None or validate_cache: + new_privkey = self._ENGINE.derive_bip32_privkey(self._master_key, path) + if privkey is None: + cache[b'p'] = privkey = new_privkey + elif privkey != new_privkey: + raise WalletError("Wallet cache validation failed") + pubkey = cache.get(b'P') + if pubkey is None or validate_cache: + new_pubkey = self._ENGINE.privkey_to_pubkey(privkey) + if pubkey is None: + cache[b'P'] = pubkey = new_pubkey + elif pubkey != new_pubkey: + raise WalletError("Wallet cache validation failed") + return privkey, pubkey, self._ENGINE + + def _get_cache_keys_for_path(self, path): + if not self._is_my_bip32_path(path): + return super()._get_cache_keys_for_path(path) + return path[:1] + tuple([self._path_level_to_repr(lvl).encode('ascii') + for lvl in path[1:]]) def _is_my_bip32_path(self, path): - return path[0] == self._key_ident + return len(path) > 0 and path[0] == self._key_ident def is_standard_wallet_script(self, path): return self._is_my_bip32_path(path) - def get_new_script(self, mixdepth, address_type): + def get_new_script(self, mixdepth, address_type, + validate_cache: bool = True): if self.disable_new_scripts: raise RuntimeError("Obtaining new wallet addresses " + "disabled, due to nohistory mode") - return self.get_new_script_override_disable(mixdepth, address_type) - - def get_new_script_override_disable(self, mixdepth, address_type): - # This is called by get_script_from_path and calls back there. We need to - # ensure all conditions match to avoid endless recursion. - index = self.get_index_cache_and_increment(mixdepth, address_type) - return self.get_script_and_update_map(mixdepth, address_type, index) + index = self._index_cache[mixdepth][address_type] + return self.get_script(mixdepth, address_type, index, + validate_cache=validate_cache) def _set_index_cache(self, mixdepth, address_type, index): """ Ensures that any update to index_cache dict only applies @@ -2232,22 +2371,6 @@ def _set_index_cache(self, mixdepth, address_type, index): assert address_type in self._get_supported_address_types() self._index_cache[mixdepth][address_type] = index - def get_index_cache_and_increment(self, mixdepth, address_type): - index = self._index_cache[mixdepth][address_type] - cur_index = self._index_cache[mixdepth][address_type] - self._set_index_cache(mixdepth, address_type, cur_index + 1) - return cur_index - - def get_script_and_update_map(self, *args): - path = self.get_path(*args) - script = self.get_script_from_path(path) - self._script_map[script] = path - return script - - def get_script(self, mixdepth, address_type, index): - path = self.get_path(mixdepth, address_type, index) - return self.get_script_from_path(path) - @deprecated def get_key(self, mixdepth, address_type, index): path = self.get_path(mixdepth, address_type, index) @@ -2392,6 +2515,10 @@ class FidelityBondMixin(object): _BIP32_PUBKEY_PREFIX = "fbonds-mpk-" + def __init__(self, storage, **kwargs): + super().__init__(storage, **kwargs) + self._populate_maps(self.yield_fidelity_bond_paths()) + @classmethod def _time_number_to_timestamp(cls, timenumber): """ @@ -2435,8 +2562,7 @@ def is_timelocked_path(cls, path): def _get_key_ident(self): first_path = self.get_path(0, BIP32Wallet.BIP32_EXT_ID) - priv, engine = self._get_key_from_path(first_path) - pub = engine.privkey_to_pubkey(priv) + pub = self._get_pubkey_from_path(first_path)[0] return sha256(sha256(pub).digest()).digest()[:3] def is_standard_wallet_script(self, path): @@ -2451,14 +2577,14 @@ def get_xpub_from_fidelity_bond_master_pub_key(cls, mpk): else: return False - def _populate_script_map(self): - super()._populate_script_map() + def yield_known_paths(self): + return chain(super().yield_known_paths(), self.yield_fidelity_bond_paths()) + + def yield_fidelity_bond_paths(self): md = self.FIDELITY_BOND_MIXDEPTH address_type = self.BIP32_TIMELOCK_ID for timenumber in range(self.TIMENUMBER_COUNT): - path = self.get_path(md, address_type, timenumber) - script = self.get_script_from_path(path) - self._script_map[script] = path + yield self.get_path(md, address_type, timenumber) def add_utxo(self, txid, index, script, value, height=None): super().add_utxo(txid, index, script, value, height) @@ -2482,16 +2608,54 @@ def get_bip32_pub_export(self, mixdepth=None, address_type=None): def _get_supported_address_types(cls): return (cls.BIP32_EXT_ID, cls.BIP32_INT_ID, cls.BIP32_TIMELOCK_ID, cls.BIP32_BURN_ID) - def _get_key_from_path(self, path): + def _get_key_from_path(self, path, + validate_cache: bool = False): if self.is_timelocked_path(path): key_path = path[:-1] locktime = path[-1] engine = self._TIMELOCK_ENGINE - privkey = engine.derive_bip32_privkey(self._master_key, key_path) + cache = super()._get_cache_for_path(key_path) + privkey = cache.get(b'p') + if privkey is None or validate_cache: + new_privkey = engine.derive_bip32_privkey(self._master_key, key_path) + if privkey is None: + cache[b'p'] = privkey = new_privkey + elif privkey != new_privkey: + raise WalletError("Wallet cache validation failed") return (privkey, locktime), engine else: return super()._get_key_from_path(path) + def _get_keypair_from_path(self, path, + validate_cache: bool = False): + if not self.is_timelocked_path(path): + return super()._get_keypair_from_path(path, + validate_cache=validate_cache) + key_path = path[:-1] + locktime = path[-1] + engine = self._TIMELOCK_ENGINE + cache = super()._get_cache_for_path(key_path) + privkey = cache.get(b'p') + if privkey is None or validate_cache: + new_privkey = engine.derive_bip32_privkey(self._master_key, key_path) + if privkey is None: + cache[b'p'] = privkey = new_privkey + elif privkey != new_privkey: + raise WalletError("Wallet cache validation failed") + pubkey = cache.get(b'P') + if pubkey is None or validate_cache: + new_pubkey = engine.privkey_to_pubkey(privkey) + if pubkey is None: + cache[b'P'] = pubkey = new_pubkey + elif pubkey != new_pubkey: + raise WalletError("Wallet cache validation failed") + return (privkey, locktime), (pubkey, locktime), engine + + def _get_cache_for_path(self, path): + if self.is_timelocked_path(path): + path = path[:-1] + return super()._get_cache_for_path(path) + def get_path(self, mixdepth=None, address_type=None, index=None): if address_type == None or address_type in (self.BIP32_EXT_ID, self.BIP32_INT_ID, self.BIP32_BURN_ID) or index == None: @@ -2537,14 +2701,6 @@ def get_details(self, path): def _get_default_used_indices(self): return {x: [0, 0, 0, 0] for x in range(self.max_mixdepth + 1)} - def get_script(self, mixdepth, address_type, index): - path = self.get_path(mixdepth, address_type, index) - return self.get_script_from_path(path) - - def get_addr(self, mixdepth, address_type, index): - script = self.get_script(mixdepth, address_type, index) - return self.script_to_addr(script) - def add_burner_output(self, path, txhex, block_height, merkle_branch, block_index, write=True): """ @@ -2644,6 +2800,43 @@ def _get_bip32_export_path(self, mixdepth=None, address_type=None): path = super()._get_bip32_export_path(mixdepth, address_type) return path + def _get_key_from_path(self, path, + validate_cache: bool = False): + raise WalletError("Cannot get a private key from a watch-only wallet") + + def _get_keypair_from_path(self, path, + validate_cache: bool = False): + raise WalletError("Cannot get a private key from a watch-only wallet") + + def _get_pubkey_from_path(self, path, + validate_cache: bool = False): + if not self._is_my_bip32_path(path): + return super()._get_pubkey_from_path(path, + validate_cache=validate_cache) + if self.is_timelocked_path(path): + key_path = path[:-1] + locktime = path[-1] + cache = self._get_cache_for_path(key_path) + pubkey = cache.get(b'P') + if pubkey is None or validate_cache: + new_pubkey = self._TIMELOCK_ENGINE.derive_bip32_privkey( + self._master_key, key_path) + if pubkey is None: + cache[b'P'] = pubkey = new_pubkey + elif pubkey != new_pubkey: + raise WalletError("Wallet cache validation failed") + return (pubkey, locktime), self._TIMELOCK_ENGINE + cache = self._get_cache_for_path(path) + pubkey = cache.get(b'P') + if pubkey is None or validate_cache: + new_pubkey = self._ENGINE.derive_bip32_privkey( + self._master_key, path) + if pubkey is None: + cache[b'P'] = pubkey = new_pubkey + elif pubkey != new_pubkey: + raise WalletError("Wallet cache validation failed") + return pubkey, self._ENGINE + WALLET_IMPLEMENTATIONS = { LegacyWallet.TYPE: LegacyWallet, diff --git a/src/jmclient/wallet_utils.py b/src/jmclient/wallet_utils.py index 3e29be2c5..786637fcf 100644 --- a/src/jmclient/wallet_utils.py +++ b/src/jmclient/wallet_utils.py @@ -7,7 +7,7 @@ from datetime import datetime, timedelta from optparse import OptionParser from numbers import Integral -from collections import Counter +from collections import Counter, defaultdict from itertools import islice, chain from jmclient import (get_network, WALLET_IMPLEMENTATIONS, Storage, podle, jm_single, WalletError, BaseWallet, VolatileStorage, @@ -403,15 +403,15 @@ def get_tx_info(txid, tx_cache=None): def get_imported_privkey_branch(wallet_service, m, showprivkey): entries = [] + balance_by_script = defaultdict(int) + for data in wallet_service.get_utxos_at_mixdepth(m, + include_disabled=True).values(): + balance_by_script[data['script']] += data['value'] for path in wallet_service.yield_imported_paths(m): addr = wallet_service.get_address_from_path(path) script = wallet_service.get_script_from_path(path) - balance = 0.0 - for data in wallet_service.get_utxos_by_mixdepth( - include_disabled=True)[m].values(): - if script == data['script']: - balance += data['value'] - status = ('used' if balance > 0.0 else 'empty') + balance = balance_by_script.get(script, 0) + status = ('used' if balance else 'empty') if showprivkey: wip_privkey = wallet_service.get_wif_path(path) else: @@ -431,9 +431,6 @@ def wallet_showutxos(wallet_service, showprivkey): includeconfs=True) for md in utxos: (enabled, disabled) = get_utxos_enabled_disabled(wallet_service, md) - utxo_d = [] - for k, v in disabled.items(): - utxo_d.append(k) for u, av in utxos[md].items(): success, us = utxo_to_utxostr(u) assert success @@ -453,7 +450,7 @@ def wallet_showutxos(wallet_service, showprivkey): 'external': False, 'mixdepth': mixdepth, 'confirmations': av['confs'], - 'frozen': True if u in utxo_d else False} + 'frozen': u in disabled} if showprivkey: unsp[us]['privkey'] = wallet_service.get_wif_path(av['path']) if locktime: @@ -1279,8 +1276,8 @@ def output_utxos(utxos, status, start=0): def get_utxos_enabled_disabled(wallet_service, md): """ Returns dicts for enabled and disabled separately """ - utxos_enabled = wallet_service.get_utxos_by_mixdepth()[md] - utxos_all = wallet_service.get_utxos_by_mixdepth(include_disabled=True)[md] + utxos_enabled = wallet_service.get_utxos_at_mixdepth(md) + utxos_all = wallet_service.get_utxos_at_mixdepth(md, include_disabled=True) utxos_disabled_keyset = set(utxos_all).difference(set(utxos_enabled)) utxos_disabled = {} for u in utxos_disabled_keyset: diff --git a/test/jmclient/test_taker.py b/test/jmclient/test_taker.py index 8d20a4331..7067382df 100644 --- a/test/jmclient/test_taker.py +++ b/test/jmclient/test_taker.py @@ -121,6 +121,12 @@ def get_txtype(self): """ return 'p2wpkh' + def _get_key_from_path(self, path, + validate_cache: bool = False): + if path[0] == b'dummy': + return struct.pack(b'B', path[2] + 1)*32 + b'\x01', self._ENGINE + raise NotImplementedError() + def get_key_from_addr(self, addr): """usable addresses: privkey all 1s, 2s, 3s, ... :""" privs = [x*32 + b"\x01" for x in [struct.pack(b'B', y) for y in range(1,6)]] @@ -139,18 +145,20 @@ def get_key_from_addr(self, addr): return p raise ValueError("No such keypair") - def _is_my_bip32_path(self, path): - return True + def get_path_repr(self, path): + return '/'.join(map(str, path)) def is_standard_wallet_script(self, path): if path[0] == "nonstandard_path": return False return True - def script_to_addr(self, script): + def script_to_addr(self, script, + validate_cache: bool = False): if self.script_to_path(script)[0] == "nonstandard_path": return "dummyaddr" - return super().script_to_addr(script) + return super().script_to_addr(script, + validate_cache=validate_cache) def dummy_order_chooser(): diff --git a/test/jmclient/test_utxomanager.py b/test/jmclient/test_utxomanager.py index 2d3023f14..1bd97e1ca 100644 --- a/test/jmclient/test_utxomanager.py +++ b/test/jmclient/test_utxomanager.py @@ -56,14 +56,12 @@ def test_utxomanager_persist(setup_env_nodeps): assert not um.is_disabled(txid, index+2) um.disable_utxo(txid, index+2) - utxos = um.get_utxos_by_mixdepth() - assert len(utxos[mixdepth]) == 1 - assert len(utxos[mixdepth+1]) == 2 - assert len(utxos[mixdepth+2]) == 0 + assert len(um.get_utxos_at_mixdepth(mixdepth)) == 1 + assert len(um.get_utxos_at_mixdepth(mixdepth+1)) == 2 + assert len(um.get_utxos_at_mixdepth(mixdepth+2)) == 0 - balances = um.get_balance_by_mixdepth() - assert balances[mixdepth] == value - assert balances[mixdepth+1] == value * 2 + assert um.get_balance_at_mixdepth(mixdepth) == value + assert um.get_balance_at_mixdepth(mixdepth+1) == value * 2 um.remove_utxo(txid, index, mixdepth) assert um.have_utxo(txid, index) == False @@ -79,14 +77,12 @@ def test_utxomanager_persist(setup_env_nodeps): assert um.have_utxo(txid, index) == False assert um.have_utxo(txid, index+1) == mixdepth + 1 - utxos = um.get_utxos_by_mixdepth() - assert len(utxos[mixdepth]) == 0 - assert len(utxos[mixdepth+1]) == 1 + assert len(um.get_utxos_at_mixdepth(mixdepth)) == 0 + assert len(um.get_utxos_at_mixdepth(mixdepth+1)) == 1 - balances = um.get_balance_by_mixdepth() - assert balances[mixdepth] == 0 - assert balances[mixdepth+1] == value - assert balances[mixdepth+2] == 0 + assert um.get_balance_at_mixdepth(mixdepth) == 0 + assert um.get_balance_at_mixdepth(mixdepth+1) == value + assert um.get_balance_at_mixdepth(mixdepth+2) == 0 def test_utxomanager_select(setup_env_nodeps): diff --git a/test/jmclient/test_wallet.py b/test/jmclient/test_wallet.py index 86d5d8e6b..45b23fa8e 100644 --- a/test/jmclient/test_wallet.py +++ b/test/jmclient/test_wallet.py @@ -17,7 +17,6 @@ wallet_gettimelockaddress, UnknownAddressForLabel from test_blockchaininterface import sync_test_wallet from freezegun import freeze_time -from bitcointx.wallet import CCoinAddressError pytestmark = pytest.mark.usefixtures("setup_regtest_bitcoind") @@ -264,9 +263,6 @@ def test_bip32_timelocked_addresses(setup_wallet, timenumber, address, wif): mixdepth = FidelityBondMixin.FIDELITY_BOND_MIXDEPTH address_type = FidelityBondMixin.BIP32_TIMELOCK_ID - #wallet needs to know about the script beforehand - wallet.get_script_and_update_map(mixdepth, address_type, timenumber) - assert address == wallet.get_addr(mixdepth, address_type, timenumber) assert wif == wallet.get_wif_path(wallet.get_path(mixdepth, address_type, timenumber)) @@ -287,7 +283,7 @@ def test_gettimelockaddress_method(setup_wallet, timenumber, locktime_string): m = FidelityBondMixin.FIDELITY_BOND_MIXDEPTH address_type = FidelityBondMixin.BIP32_TIMELOCK_ID - script = wallet.get_script_and_update_map(m, address_type, timenumber) + script = wallet.get_script(m, address_type, timenumber) addr = wallet.script_to_addr(script) addr_from_method = wallet_gettimelockaddress(wallet, locktime_string) @@ -456,7 +452,7 @@ def test_timelocked_output_signing(setup_wallet): wallet = SegwitWalletFidelityBonds(storage) timenumber = 0 - script = wallet.get_script_and_update_map( + script = wallet.get_script( FidelityBondMixin.FIDELITY_BOND_MIXDEPTH, FidelityBondMixin.BIP32_TIMELOCK_ID, timenumber) utxo = fund_wallet_addr(wallet, wallet.script_to_addr(script)) @@ -477,7 +473,7 @@ def test_get_bbm(setup_wallet): wallet = get_populated_wallet(amount, num_tx) # disable a utxo and check we can correctly report # balance with the disabled flag off: - utxo_1 = list(wallet._utxos.get_utxos_by_mixdepth()[0].keys())[0] + utxo_1 = list(wallet._utxos.get_utxos_at_mixdepth(0).keys())[0] wallet.disable_utxo(*utxo_1) balances = wallet.get_balance_by_mixdepth(include_disabled=True) assert balances[0] == num_tx * amount @@ -610,7 +606,9 @@ def test_address_labels(setup_wallet): wallet.get_address_label("2MzY5yyonUY7zpHspg7jB7WQs1uJxKafQe4") wallet.set_address_label("2MzY5yyonUY7zpHspg7jB7WQs1uJxKafQe4", "test") - with pytest.raises(CCoinAddressError): + # we no longer decode addresses just to see if we know about them, + # so we won't get a CCoinAddressError for invalid addresses + #with pytest.raises(CCoinAddressError): wallet.get_address_label("badaddress") wallet.set_address_label("badaddress", "test")