diff --git a/hwilib/_cli.py b/hwilib/_cli.py index c2a4ac28f..dad7753c8 100644 --- a/hwilib/_cli.py +++ b/hwilib/_cli.py @@ -12,6 +12,7 @@ getdescriptors, prompt_pin, toggle_passphrase, + registerpolicy, restore_device, send_pin, setup_device, @@ -57,7 +58,7 @@ def backup_device_handler(args: argparse.Namespace, client: HardwareWalletClient return backup_device(client, label=args.label, backup_passphrase=args.backup_passphrase) def displayaddress_handler(args: argparse.Namespace, client: HardwareWalletClient) -> Dict[str, str]: - return displayaddress(client, desc=args.desc, path=args.path, addr_type=args.addr_type) + return displayaddress(client, desc=args.desc, path=args.path, policy=args.policy, addr_type=args.addr_type, name=args.name, keys=args.keys, change=args.change, index=args.index, extra=args.extra) def enumerate_handler(args: argparse.Namespace) -> List[Dict[str, Any]]: return enumerate(password=args.password) @@ -74,6 +75,9 @@ def getkeypool_handler(args: argparse.Namespace, client: HardwareWalletClient) - def getdescriptors_handler(args: argparse.Namespace, client: HardwareWalletClient) -> Dict[str, List[str]]: return getdescriptors(client, account=args.account) +def registerpolicy_handler(args: argparse.Namespace, client: HardwareWalletClient) -> Dict[str, str]: + return registerpolicy(client, name=args.name, policy=args.policy, keys=args.keys, extra=args.extra) + def restore_device_handler(args: argparse.Namespace, client: HardwareWalletClient) -> Dict[str, bool]: if args.interactive: return restore_device(client, label=args.label, word_count=args.word_count) @@ -88,7 +92,7 @@ def signmessage_handler(args: argparse.Namespace, client: HardwareWalletClient) return signmessage(client, message=args.message, path=args.path) def signtx_handler(args: argparse.Namespace, client: HardwareWalletClient) -> Dict[str, Union[bool, str]]: - return signtx(client, psbt=args.psbt) + return signtx(client, psbt=args.psbt, name=args.name, policy=args.policy, keys=args.keys, extra=args.extra) def wipe_device_handler(args: argparse.Namespace, client: HardwareWalletClient) -> Dict[str, bool]: return wipe_device(client) @@ -160,6 +164,11 @@ def get_parser() -> HWIArgumentParser: signtx_parser = subparsers.add_parser('signtx', help='Sign a PSBT') signtx_parser.add_argument('psbt', help='The Partially Signed Bitcoin Transaction to sign') + signtx_parser.add_argument('--policy', help='The descriptor template of the wallet policy. E.g. wpkh(@0/**)', type=str) + signtx_parser.add_argument('--name', help='The name of the policy. E.g. "Cold storage"', type=str) + signtx_parser.add_argument('--keys', help='The list of keys in the wallet policy, encoded in JSON', type=str) + signtx_parser.add_argument('--extra', help='JSON encoded string; it might contain proof_of_registration, or vendor-specific fields', type=str, default="{}") + signtx_parser.set_defaults(func=signtx_handler) getxpub_parser = subparsers.add_parser('getxpub', help='Get an extended public key') @@ -192,8 +201,14 @@ def get_parser() -> HWIArgumentParser: displayaddr_parser = subparsers.add_parser('displayaddress', help='Display an address') group = displayaddr_parser.add_mutually_exclusive_group(required=True) group.add_argument('--desc', help='Output Descriptor. E.g. wpkh([00000000/84h/0h/0h]xpub.../0/0), where 00000000 must match --fingerprint and xpub can be obtained with getxpub. See doc/descriptors.md in Bitcoin Core') + group.add_argument('--policy', help='The descriptor template of the wallet policy. E.g. wpkh(@0/**)') group.add_argument('--path', help='The BIP 32 derivation path of the key embedded in the address, default follows BIP43 convention, e.g. ``m/84h/0h/0h/1/*``') displayaddr_parser.add_argument("--addr-type", help="The address type to display", type=AddressType.argparse, choices=list(AddressType), default=AddressType.WIT) # type: ignore + displayaddr_parser.add_argument('--name', help='The name of the policy. E.g. "Cold storage". Can be empty for default single-signature wallets') + displayaddr_parser.add_argument('--keys', help='The list of keys in the wallet policy, encoded in JSON') + displayaddr_parser.add_argument('--change', type=int, help='0 if not change, 1 if change', default=0) # TODO: can we use 'choices=' here? + displayaddr_parser.add_argument('--index', help='address index', type=int, default=0) + displayaddr_parser.add_argument('--extra', help='JSON encoded string; it might contain proof_of_registration, or vendor-specific fields', type=str, default="{}") displayaddr_parser.set_defaults(func=displayaddress_handler) setupdev_parser = subparsers.add_parser('setup', help='Setup a device. Passphrase protection uses the password given by -p. Requires interactive mode') @@ -204,6 +219,13 @@ def get_parser() -> HWIArgumentParser: wipedev_parser = subparsers.add_parser('wipe', help='Wipe a device') wipedev_parser.set_defaults(func=wipe_device_handler) + registerpolicy_parser = subparsers.add_parser('registerpolicy', help='Register a policy') + registerpolicy_parser.add_argument('--name', help='The name of the policy. E.g. "Cold storage"', type=str, required=True) + registerpolicy_parser.add_argument('--policy', help='The descriptor template of the wallet policy. E.g. wpkh(@0/**)', type=str, required=True) + registerpolicy_parser.add_argument('--keys', help='The list of keys in the wallet policy, encoded in JSON', type=str, required=True) + registerpolicy_parser.add_argument('--extra', help='JSON encoded string; it might contain proof_of_registration, or vendor-specific fields', type=str, default="{}") + registerpolicy_parser.set_defaults(func=registerpolicy_handler) + restore_parser = subparsers.add_parser('restore', help='Initiate the device restoring process. Requires interactive mode') restore_parser.add_argument('--word_count', '-w', help='Word count of your BIP39 recovery phrase (options: 12/18/24)', type=int, default=24) restore_parser.add_argument('--label', '-l', help='The name to give to the device', default='') diff --git a/hwilib/commands.py b/hwilib/commands.py index 318f883f9..5135f2432 100644 --- a/hwilib/commands.py +++ b/hwilib/commands.py @@ -21,6 +21,7 @@ import importlib import logging import platform +import json from ._base58 import xpub_to_pub_hex, xpub_to_xonly_pub_hex from .key import ( @@ -49,6 +50,7 @@ WPKHDescriptor, WSHDescriptor, ) +from .wallet_policy import WalletPolicy from .devices import __all__ as all_devs from .common import ( AddressType, @@ -182,19 +184,44 @@ def getmasterxpub(client: HardwareWalletClient, addrtype: AddressType = AddressT """ return {"xpub": client.get_master_xpub(addrtype, account).to_string()} -def signtx(client: HardwareWalletClient, psbt: str) -> Dict[str, Union[bool, str]]: +def signtx( + client: HardwareWalletClient, + psbt: str, + policy: Optional[str], + name: Optional[str], + keys: Optional[str], + extra: str = "{}" +) -> Dict[str, Union[bool, str]]: """ Sign a Partially Signed Bitcoin Transaction (PSBT) with the client. :param client: The client to interact with :param psbt: The PSBT to sign + :param policy: The descriptor template for the wallet policy to display the address for. Mutually exclusive with ``desc`` or ``path`` + :param name: The name of the wallet policy, if registered. Only works with ``policy``; if omitted, empty name is implied + :param keys: The list of keys information, as a JSON-encoded string. Only works with ``policy`` + :param extra: A JSON-encoded string representing a dictionary with any additional field for the hardware wallet :return: A dictionary containing the processed PSBT serialized in Base64. Returned as ``{"psbt": }``. """ # Deserialize the transaction tx = PSBT() tx.deserialize(psbt) - result = client.sign_tx(tx).serialize() + if policy is None: + result = client.sign_tx(tx).serialize() + else: + if keys is None: + raise BadArgumentError("--keys parameter is compulsory when using --policy") + + keys_info = json.loads(keys) + if not isinstance(keys_info, list) or any(not isinstance(k, str) for k in keys_info): + raise BadArgumentError("keys should be a json-encoded list of keys information") + + # TODO: could do more validation of each key origin info + + parsed_extra = json.loads(extra) + wp = WalletPolicy(name if name is not None else "", policy, keys_info, parsed_extra) + result = client.sign_tx_with_wallet_policy(tx, wp, parsed_extra).serialize() return {"psbt": result, "signed": result != psbt} def getxpub(client: HardwareWalletClient, path: str, expert: bool = False) -> Dict[str, Any]: @@ -436,20 +463,68 @@ def getdescriptors( return result + +def registerpolicy( + client: HardwareWalletClient, + policy: str, + name: str, + keys: str, + extra: str = "{}" +) -> Dict[str, str]: + """ + Display an address on the device for client. + The address can be specified by the path with additional parameters, or by a descriptor. + + :param client: The client to interact with + :param policy: The descriptor template for the wallet policy to display the address for. Mutually exclusive with ``desc`` or ``path`` + :param name: The name of the wallet policy, if registered. Only works with ``policy``; if omitted, empty name is implied + :param keys: The list of keys information, as a JSON-encoded string. Only works with ``policy`` + :param extra: A JSON-encoded string representing a dictionary with any additional field for the hardware wallet + :return: On success, returns the proof of registration if the hardware wallet requires it. + Returned as ``{"proof_of_registration": }``. The proof_of_registration is the empty string is the hardware + wallet does not return a proof of registration. + :raises: BadArgumentError: if an argument is malformed or missing. + """ + if name == "": + raise BadArgumentError("The policy name cannot be empty") + + keys_info = json.loads(keys) + if not isinstance(keys_info, list) or any(not isinstance(k, str) for k in keys_info): + raise BadArgumentError("keys should be a json-encoded list of keys information") + + # TODO: could do more validation of each key origin info + + parsed_extra = json.loads(extra) + wp = WalletPolicy(name if name is not None else "", policy, keys_info, parsed_extra) + return {"proof_of_registration": client.register_wallet_policy(wp, parsed_extra).hex()} + + def displayaddress( client: HardwareWalletClient, path: Optional[str] = None, desc: Optional[str] = None, - addr_type: AddressType = AddressType.WIT + policy: Optional[str] = None, + addr_type: AddressType = AddressType.WIT, + name: Optional[str] = None, + keys: Optional[str] = None, + change: Optional[int] = 0, + index: Optional[int] = 0, + extra: str = "{}" ) -> Dict[str, str]: """ Display an address on the device for client. The address can be specified by the path with additional parameters, or by a descriptor. :param client: The client to interact with - :param path: The path of the address to display. Mutually exclusive with ``desc`` - :param desc: The descriptor to display the address for. Mutually exclusive with ``path`` + :param path: The path of the address to display. Mutually exclusive with ``desc`` or ``policy`` + :param desc: The descriptor to display the address for. Mutually exclusive with ``path`` or ``policy`` + :param policy: The descriptor template for the wallet policy to display the address for. Mutually exclusive with ``desc`` or ``path`` :param addr_type: The address type to return. Only works with ``path`` + :param name: The name of the wallet policy, if registered. Only works with ``policy``; if omitted, empty name is implied + :param keys: The list of keys information, as a JSON-encoded string. Only works with ``policy`` + :param change: 0 if a normal receive address is desired, 1 for a change address. Only works with ``policy`` + :param index: The address index of the required address, a number between 0 and 2147483647 (include). Only works with ``policy`` + :param extra: A JSON-encoded string representing a dictionary with any additional field for the hardware wallet :return: A dictionary containing the address displayed. Returned as ``{"address": }``. :raises: BadArgumentError: if an argument is malformed, missing, or conflicts. @@ -491,7 +566,21 @@ def displayaddress( elif isinstance(descriptor, TRDescriptor): addr_type = AddressType.TAP return {"address": client.display_singlesig_address(pubkey.get_full_derivation_path(0), addr_type)} - raise BadArgumentError("Missing both path and descriptor") + elif policy is not None: + if keys is None: + raise BadArgumentError("--keys parameter is compulsory when using --policy") + + keys_info = json.loads(keys) + if not isinstance(keys_info, list) or any(not isinstance(k, str) for k in keys_info): + raise BadArgumentError("keys should be a json-encoded list of keys information") + + # TODO: could do more validation of each key origin info + + parsed_extra = json.loads(extra) + wp = WalletPolicy(name if name is not None else "", policy, keys_info, parsed_extra) + return {"address": client.display_wallet_policy_address(wp, change == 1, index, parsed_extra)} + + raise BadArgumentError("Missing all of path, descriptor and policy") def setup_device(client: HardwareWalletClient, label: str = "", backup_passphrase: str = "") -> Dict[str, bool]: """ diff --git a/hwilib/devices/ledger.py b/hwilib/devices/ledger.py index b7da4abc3..76ac4e75f 100644 --- a/hwilib/devices/ledger.py +++ b/hwilib/devices/ledger.py @@ -39,11 +39,8 @@ LegacyClient, TransportClient, ) +from .ledger_bitcoin.command_builder import get_wallet_policy_id from .ledger_bitcoin.exception import NotSupportedError -from .ledger_bitcoin.wallet import ( - MultisigWallet, - WalletPolicy, -) from .ledger_bitcoin.btchip.btchipException import BTChipException import builtins @@ -66,6 +63,7 @@ parse_multisig, ) from ..psbt import PSBT +from ..wallet_policy import WalletPolicy import logging import re @@ -138,6 +136,41 @@ def func(*args: Any, **kwargs: Any) -> Any: raise e return func +class MultisigWallet(WalletPolicy): + """Helper class to represent common multisignature wallet policies.""" + + def __init__(self, name: str, address_type: AddressType, threshold: int, keys_info: List[str], sorted: bool = True) -> None: + n_keys = len(keys_info) + + if not (1 <= threshold <= n_keys <= 16): + raise ValueError("Invalid threshold or number of keys") + + multisig_op = "sortedmulti" if sorted else "multi" + + if (address_type == AddressType.LEGACY): + policy_prefix = f"sh({multisig_op}(" + policy_suffix = "))" + elif address_type == AddressType.WIT: + policy_prefix = f"wsh({multisig_op}(" + policy_suffix = "))" + elif address_type == AddressType.SH_WIT: + policy_prefix = f"sh(wsh({multisig_op}(" + policy_suffix = ")))" + else: + raise ValueError(f"Unexpected address type: {address_type}") + + descriptor_template = "".join([ + policy_prefix, + str(threshold) + ",", + ",".join("@" + str(k) + "/**" for k in range(n_keys)), + policy_suffix + ]) + + super().__init__(name, descriptor_template, keys_info) + + self.threshold = threshold + + # This class extends the HardwareWalletClient for Ledger Nano S and Nano X specific things class LedgerClient(HardwareWalletClient): @@ -289,7 +322,7 @@ def legacy_sign_tx() -> PSBT: # Make and register the MultisigWallet msw = MultisigWallet(f"{k} of {len(key_exprs)} Multisig", script_addrtype, k, key_exprs) - msw_id = msw.id + msw_id = get_wallet_policy_id(msw) if msw_id not in wallets: _, registered_hmac = self.client.register_wallet(msw) wallets[msw_id] = ( @@ -304,7 +337,8 @@ def process_origin(origin: KeyOriginInfo) -> None: # TODO: Deal with non-default wallets return policy = self._get_singlesig_default_wallet_policy(script_addrtype, origin.path[2]) - wallets[policy.id] = ( + policy_id = get_wallet_policy_id(policy) + wallets[policy_id] = ( signing_priority[script_addrtype], script_addrtype, self._get_singlesig_default_wallet_policy(script_addrtype, origin.path[2]), @@ -538,6 +572,74 @@ def can_sign_taproot(self) -> bool: """ return isinstance(self.client, NewClient) + def can_register_wallet_policies(self) -> bool: + return True + + @ledger_exception + def register_wallet_policy(self, wallet_policy: WalletPolicy, extra: Dict[str, Any]) -> bytes: + _, hmac = self.client.register_wallet(wallet_policy) + return hmac + + @ledger_exception + def display_wallet_policy_address(self, wallet_policy: WalletPolicy, is_change: bool, address_index: int, extra: Dict[str, Any]) -> str: + # TODO: if non standard policy, should return an error if proof_of_registration is missing + if "proof_of_registration" not in extra: + hmac = None + else: + # TODO: error handling + hmac = bytes.fromhex(extra["proof_of_registration"]) + assert len(hmac) == 32 + + return self.client.get_wallet_address(wallet_policy, hmac, int(is_change), address_index, True) + + @ledger_exception + def sign_tx_with_wallet_policy(self, psbt: PSBT, wallet_policy: WalletPolicy, extra: Dict[str, Any]) -> PSBT: + # TODO: proof_of_registration not required for standard policies + if "proof_of_registration" not in extra: + raise BadArgumentError("Ledger Bitcoin app requires a proof_of_registration for non-standard wallet policies") + + # TODO: proper error handling + wallet_hmac = bytes.fromhex(extra["proof_of_registration"]) + + assert len(wallet_hmac) == 32 + + # Make a deepcopy of this psbt. We will need to modify it to get signing to work, + # which will affect the caller's detection for whether signing occured. + psbt2 = copy.deepcopy(psbt) + if psbt.version != 2: + psbt2.convert_to_v2() + + input_sigs = self.client.sign_psbt(psbt2, wallet_policy, wallet_hmac) + + for idx, pubkey, sig in input_sigs: + psbt_in = psbt2.inputs[idx] + + utxo = None + if psbt_in.witness_utxo: + utxo = psbt_in.witness_utxo + if psbt_in.non_witness_utxo: + assert psbt_in.prev_out is not None + utxo = psbt_in.non_witness_utxo.vout[psbt_in.prev_out] + assert utxo is not None + + is_wit, wit_ver, _ = utxo.is_witness() + + if is_wit and wit_ver >= 1: + # TODO: Deal with script path signatures + # For now, assume key path signature + psbt_in.tap_key_sig = sig + else: + psbt_in.partial_sigs[pubkey] = sig + + # Extract the sigs from psbt2 and put them into tx + for sig_in, psbt_in in zip(psbt2.inputs, psbt.inputs): + psbt_in.partial_sigs.update(sig_in.partial_sigs) + psbt_in.tap_script_sigs.update(sig_in.tap_script_sigs) + if len(sig_in.tap_key_sig) != 0 and len(psbt_in.tap_key_sig) == 0: + psbt_in.tap_key_sig = sig_in.tap_key_sig + + return psbt + def enumerate(password: str = '') -> List[Dict[str, Any]]: results = [] diff --git a/hwilib/devices/ledger_bitcoin/__init__.py b/hwilib/devices/ledger_bitcoin/__init__.py index 085351049..6d6d522d1 100644 --- a/hwilib/devices/ledger_bitcoin/__init__.py +++ b/hwilib/devices/ledger_bitcoin/__init__.py @@ -5,6 +5,4 @@ from .client import createClient from ...common import Chain -from .wallet import AddressType, WalletPolicy, MultisigWallet - -__all__ = ["Client", "TransportClient", "createClient", "Chain", "AddressType", "WalletPolicy", "MultisigWallet"] +__all__ = ["Client", "TransportClient", "createClient", "Chain"] diff --git a/hwilib/devices/ledger_bitcoin/client.py b/hwilib/devices/ledger_bitcoin/client.py index 64bbe3adc..7501ca9a1 100644 --- a/hwilib/devices/ledger_bitcoin/client.py +++ b/hwilib/devices/ledger_bitcoin/client.py @@ -2,14 +2,14 @@ import base64 from io import BytesIO, BufferedReader -from .command_builder import BitcoinCommandBuilder, BitcoinInsType +from .command_builder import BitcoinCommandBuilder, BitcoinInsType, serialize_wallet_policy from ...common import Chain from .client_command import ClientCommandInterpreter from .client_base import Client, TransportClient from .client_legacy import LegacyClient from .exception import DeviceException, NotSupportedError from .merkle import get_merkleized_map_commitment -from .wallet import WalletPolicy, WalletType +from ...wallet_policy import WalletPolicy from ...psbt import PSBT from ..._serialize import deser_string @@ -97,11 +97,8 @@ def get_extended_pubkey(self, path: str, display: bool = False) -> str: return response.decode() def register_wallet(self, wallet: WalletPolicy) -> Tuple[bytes, bytes]: - if wallet.version not in [WalletType.WALLET_POLICY_V1, WalletType.WALLET_POLICY_V2]: - raise ValueError("invalid wallet policy version") - client_intepreter = ClientCommandInterpreter() - client_intepreter.add_known_preimage(wallet.serialize()) + client_intepreter.add_known_preimage(serialize_wallet_policy(wallet)) client_intepreter.add_known_list([k.encode() for k in wallet.keys_info]) # necessary for version 1 of the protocol (available since version 2.1.0 of the app) @@ -131,15 +128,12 @@ def get_wallet_address( display: bool, ) -> str: - if not isinstance(wallet, WalletPolicy) or wallet.version not in [WalletType.WALLET_POLICY_V1, WalletType.WALLET_POLICY_V2]: - raise ValueError("wallet type must be WalletPolicy, with version either WALLET_POLICY_V1 or WALLET_POLICY_V2") - if change != 0 and change != 1: raise ValueError("Invalid change") client_intepreter = ClientCommandInterpreter() client_intepreter.add_known_list([k.encode() for k in wallet.keys_info]) - client_intepreter.add_known_preimage(wallet.serialize()) + client_intepreter.add_known_preimage(serialize_wallet_policy(wallet)) # necessary for version 1 of the protocol (available since version 2.1.0 of the app) client_intepreter.add_known_preimage(wallet.descriptor_template.encode()) @@ -197,7 +191,7 @@ def sign_psbt(self, psbt: PSBT, wallet: WalletPolicy, wallet_hmac: Optional[byte client_intepreter = ClientCommandInterpreter() client_intepreter.add_known_list([k.encode() for k in wallet.keys_info]) - client_intepreter.add_known_preimage(wallet.serialize()) + client_intepreter.add_known_preimage(serialize_wallet_policy(wallet)) # necessary for version 1 of the protocol (available since version 2.1.0 of the app) client_intepreter.add_known_preimage(wallet.descriptor_template.encode()) diff --git a/hwilib/devices/ledger_bitcoin/client_base.py b/hwilib/devices/ledger_bitcoin/client_base.py index 5b846963f..81dc293d0 100644 --- a/hwilib/devices/ledger_bitcoin/client_base.py +++ b/hwilib/devices/ledger_bitcoin/client_base.py @@ -8,7 +8,7 @@ from .command_builder import DefaultInsType from .exception import DeviceException -from .wallet import WalletPolicy +from ...wallet_policy import WalletPolicy from ...psbt import PSBT from ..._serialize import deser_string diff --git a/hwilib/devices/ledger_bitcoin/client_legacy.py b/hwilib/devices/ledger_bitcoin/client_legacy.py index b545fe593..d9b7f72df 100644 --- a/hwilib/devices/ledger_bitcoin/client_legacy.py +++ b/hwilib/devices/ledger_bitcoin/client_legacy.py @@ -17,7 +17,7 @@ from ...common import AddressType, Chain, hash160 from ...key import ExtendedKey, parse_path from ...psbt import PSBT -from .wallet import WalletPolicy +from ...wallet_policy import WalletPolicy from ..._script import is_p2sh, is_witness, is_p2wpkh, is_p2wsh diff --git a/hwilib/devices/ledger_bitcoin/command_builder.py b/hwilib/devices/ledger_bitcoin/command_builder.py index 30a804e1a..016fcd626 100644 --- a/hwilib/devices/ledger_bitcoin/command_builder.py +++ b/hwilib/devices/ledger_bitcoin/command_builder.py @@ -1,13 +1,19 @@ import enum from typing import List, Tuple, Mapping, Union, Iterator, Optional +from hashlib import sha256 + from ..._serialize import ser_compact_size as write_varint from .merkle import get_merkleized_map_commitment, MerkleTree, element_hash -from .wallet import WalletPolicy +from ...wallet_policy import WalletPolicy # p2 encodes the protocol version implemented CURRENT_PROTOCOL_VERSION = 1 +# version number in Ledger's serialization of wallet policies +CURRENT_WALLET_POLICY_VERSION = 2 + + def bip32_path_from_string(path: str) -> List[bytes]: splitted_path: List[str] = path.split("/") @@ -41,6 +47,23 @@ def chunkify(data: bytes, chunk_len: int) -> Iterator[Tuple[bool, bytes]]: yield True, data[offset:] +def serialize_str(value: str) -> bytes: + return len(value).to_bytes(1, byteorder="big") + value.encode("latin-1") + +def serialize_wallet_policy(wallet_policy: WalletPolicy) -> bytes: + keys_info_hashes = map(lambda k: element_hash(k.encode()), wallet_policy.keys_info) + return b"".join([ + CURRENT_WALLET_POLICY_VERSION.to_bytes(1, byteorder="big"), + serialize_str(wallet_policy.name), + write_varint(len(wallet_policy.descriptor_template.encode())), + sha256(wallet_policy.descriptor_template.encode()).digest(), + write_varint(len(wallet_policy.keys_info)), + MerkleTree(keys_info_hashes).root + ]) + +def get_wallet_policy_id(wallet_policy: WalletPolicy) -> bytes: + return sha256(serialize_wallet_policy(wallet_policy)).digest() + class DefaultInsType(enum.IntEnum): GET_VERSION = 0x01 @@ -112,7 +135,7 @@ def get_extended_pubkey(self, bip32_path: List[int], display: bool = False): ) def register_wallet(self, wallet: WalletPolicy): - wallet_bytes = wallet.serialize() + wallet_bytes = serialize_wallet_policy(wallet) return self.serialize( cla=self.CLA_BITCOIN, @@ -131,7 +154,7 @@ def get_wallet_address( cdata: bytes = b"".join( [ b'\1' if display else b'\0', # 1 byte - wallet.id, # 32 bytes + get_wallet_policy_id(wallet), # 32 bytes wallet_hmac if wallet_hmac is not None else b'\0' * 32, # 32 bytes b"\1" if change else b"\0", # 1 byte address_index.to_bytes(4, byteorder="big"), # 4 bytes @@ -172,7 +195,7 @@ def sign_psbt( ] ).root - cdata += wallet.id + cdata += get_wallet_policy_id(wallet) cdata += wallet_hmac if wallet_hmac is not None else b'\0' * 32 return self.serialize( diff --git a/hwilib/devices/ledger_bitcoin/wallet.py b/hwilib/devices/ledger_bitcoin/wallet.py deleted file mode 100644 index ff732abd3..000000000 --- a/hwilib/devices/ledger_bitcoin/wallet.py +++ /dev/null @@ -1,128 +0,0 @@ -import re - -from enum import IntEnum -from typing import List - -from hashlib import sha256 - -from ...common import AddressType -from .merkle import MerkleTree, element_hash -from ..._serialize import ser_compact_size as write_varint - - -def serialize_str(value: str) -> bytes: - return len(value).to_bytes(1, byteorder="big") + value.encode("latin-1") - - -class WalletType(IntEnum): - WALLET_POLICY_V1 = 1 - WALLET_POLICY_V2 = 2 - - -# should not be instantiated directly -class WalletPolicyBase: - def __init__(self, name: str, version: WalletType) -> None: - self.name = name - self.version = version - - if (version != WalletType.WALLET_POLICY_V1 and version != WalletType.WALLET_POLICY_V2): - raise ValueError("Invalid wallet policy version") - - def serialize(self) -> bytes: - return b"".join([ - self.version.value.to_bytes(1, byteorder="big"), - serialize_str(self.name) - ]) - - @property - def id(self) -> bytes: - return sha256(self.serialize()).digest() - - -class WalletPolicy(WalletPolicyBase): - """ - Represents a wallet stored with a wallet policy. - For version V2, the wallet is serialized as follows: - - 1 byte : wallet version - - 1 byte : length of the wallet name (max 64) - - (var) : wallet name (ASCII string) - - (varint) : length of the descriptor template - - 32-bytes : sha256 hash of the descriptor template - - (varint) : number of keys (not larger than 252) - - 32-bytes : root of the Merkle tree of all the keys information. - - The specific format of the keys is deferred to subclasses. - """ - - def __init__(self, name: str, descriptor_template: str, keys_info: List[str], version: WalletType = WalletType.WALLET_POLICY_V2): - super().__init__(name, version) - self.descriptor_template = descriptor_template - self.keys_info = keys_info - - @property - def n_keys(self) -> int: - return len(self.keys_info) - - def serialize(self) -> bytes: - keys_info_hashes = map(lambda k: element_hash(k.encode()), self.keys_info) - - descriptor_template_sha256 = sha256(self.descriptor_template.encode()).digest() - - return b"".join([ - super().serialize(), - write_varint(len(self.descriptor_template.encode())), - self.descriptor_template.encode() if self.version == WalletType.WALLET_POLICY_V1 else descriptor_template_sha256, - write_varint(len(self.keys_info)), - MerkleTree(keys_info_hashes).root - ]) - - def get_descriptor(self, change: bool) -> str: - desc = self.descriptor_template - for i in reversed(range(self.n_keys)): - key = self.keys_info[i] - desc = desc.replace(f"@{i}", key) - - # in V1, /** is part of the key; in V2, it's part of the policy map. This handles either - desc = desc.replace("/**", f"/{1 if change else 0}/*") - - if self.version == WalletType.WALLET_POLICY_V2: - # V2, the / syntax is supported. Replace with M if not change, or with N if change - regex = r"/<(\d+);(\d+)>" - desc = re.sub(regex, "/\\2" if change else "/\\1", desc) - - return desc - - -class MultisigWallet(WalletPolicy): - def __init__(self, name: str, address_type: AddressType, threshold: int, keys_info: List[str], sorted: bool = True, version: WalletType = WalletType.WALLET_POLICY_V2) -> None: - n_keys = len(keys_info) - - if not (1 <= threshold <= n_keys <= 16): - raise ValueError("Invalid threshold or number of keys") - - multisig_op = "sortedmulti" if sorted else "multi" - - if (address_type == AddressType.LEGACY): - policy_prefix = f"sh({multisig_op}(" - policy_suffix = "))" - elif address_type == AddressType.WIT: - policy_prefix = f"wsh({multisig_op}(" - policy_suffix = "))" - elif address_type == AddressType.SH_WIT: - policy_prefix = f"sh(wsh({multisig_op}(" - policy_suffix = ")))" - else: - raise ValueError(f"Unexpected address type: {address_type}") - - key_placeholder_suffix = "/**" if version == WalletType.WALLET_POLICY_V2 else "" - - descriptor_template = "".join([ - policy_prefix, - str(threshold) + ",", - ",".join("@" + str(k) + key_placeholder_suffix for k in range(n_keys)), - policy_suffix - ]) - - super().__init__(name, descriptor_template, keys_info, version) - - self.threshold = threshold diff --git a/hwilib/hwwclient.py b/hwilib/hwwclient.py index 99092273d..d279df1a2 100644 --- a/hwilib/hwwclient.py +++ b/hwilib/hwwclient.py @@ -6,11 +6,14 @@ """ from typing import ( + Any, Dict, Optional, Union, ) -from .descriptor import MultisigDescriptor + +from .wallet_policy import WalletPolicy +from .descriptor import MultisigDescriptor, parse_descriptor from .key import ( ExtendedKey, get_bip44_purpose, @@ -237,3 +240,106 @@ def can_sign_taproot(self) -> bool: """ raise NotImplementedError("The HardwareWalletClient base class " "does not implement this method") + + def can_register_wallet_policies(self) -> bool: + """ + Whether the device can register wallet policies + + :return: Whether wallet policies are supported + """ + return False + + def register_wallet_policy(self, wallet_policy: WalletPolicy, extra: Dict[str, Any]) -> bytes: + """ + Registers a wallet policy on the device. + + :param wallet_policy: The WalletPolicy to register + :param extra: A dictionary with any additional parameters needed from the hardware wallet + :return: A binary string to be used as proof_of_registration if required from the hardware wallet, or + an empty binary string b'' if the hardware wallet does not require a proof of registration + """ + raise NotImplementedError("The HardwareWalletClient base class " + "does not implement this method") + + def display_wallet_policy_address(self, wallet_policy: WalletPolicy, is_change: bool, address_index: int, extra: Dict[str, Any]) -> str: + """ + Display and return the specified address for the wallet policy. + + :param wallet_policy: The WalletPolicy to be used to show an address from + :param is_change: False if a normal receive address is shown, True for a change address + :param address_index: The index of the address to show. It must be at least 0 and at most 2147483647. + :param extra: Any additional parameters needed from the hardware wallet + :return: The retrieved address also being shown by the device + """ + + # TODO: Default behavior: if the wallet policy is a standard single signature address, use display_singlesig_address; + # if standard multisig, use display_multisig_address; otherwise, fail. + + fpr = self.get_master_fingerprint() + + if len(wallet_policy.keys_info) == 1 and wallet_policy.keys_info[0][0] == '[': + + key_info = wallet_policy.keys_info[0] + key_fingerprint = bytes.fromhex(key_info[1:9]) + + assert key_info.find(']') != -1 + + if fpr == key_fingerprint: + if wallet_policy.descriptor_template in ["pkh(@0/**)", "pkh(@0/<0;1>/*)"]: + addr_type = AddressType.LEGACY + elif wallet_policy.descriptor_template in ["wpkh(@0/**)", "wpkh(@0/<0;1>/*)"]: + addr_type = AddressType.WIT + elif wallet_policy.descriptor_template in ["sh(wpkh(@0/**))", "sh(wpkh(@0/<0;1>/*))"]: + addr_type = AddressType.SH_WIT + elif wallet_policy.descriptor_template in ["tr(@0/**)", "tr(@0/<0;1>/*)"]: + addr_type = AddressType.TAP + else: + raise NotImplementedError("The HardwareWalletClient base class " + "does not implement this method") + + # TODO we're assuming here that if the fingerprint matches, it's an internal input. + # we might want to compare the derived pubkey and compare in order to be sure + + key_origin_path = key_info[10:key_info.find(']')] + bip32_path = f"m/{key_origin_path}/{int(is_change)}/{address_index}" + + return self.display_singlesig_address(bip32_path, addr_type) + else: + desc = wallet_policy.to_descriptor() + desc = desc.replace("/<0;1>/*", f"/{int(is_change)}/{address_index}") + + addr_type: Optional[AddressType] = None + + if desc.startswith("sh(wsh("): + multisig_desc = desc[len("sh(wsh("):-2] + addr_type = AddressType.SH_WIT + elif desc.startswith("wsh("): + multisig_desc = desc[len("wsh("):-1] + addr_type = AddressType.WIT + elif desc.startswith("sh("): + multisig_desc = desc[len("sh("):-1] + addr_type = AddressType.LEGACY + + if addr_type is not None: + descriptor = parse_descriptor(multisig_desc) + return self.display_multisig_address(addr_type, descriptor) + + raise NotImplementedError("The HardwareWalletClient base class does not " + "implement this method for the given parameters") + + def sign_tx_with_wallet_policy(self, psbt: PSBT, wallet_policy: WalletPolicy, extra: Dict[str, Any]) -> PSBT: + """ + Sign a partially signed bitcoin transaction (PSBT) using the specified wallet policy. + Inputs that do not belong to the wallet policy (or with insufficient information in the PSBT) are not signed. + + :param psbt: The PSBT to sign + :param wallet_policy: The WalletPolicy to be used for signing + :param extra: A dictionary with any additional parameters needed from the hardware wallet + :return: The PSBT after being processed by the hardware wallet + """ + + # TODO: Default behavior: fail if not standard wallet policy; then, sign using sign_tx, but only retain signatures + # for inputs that match the wallet policy + + # TODO: just for initial testing, yolo-sign everything + return self.sign_tx(psbt) diff --git a/hwilib/wallet_policy.py b/hwilib/wallet_policy.py new file mode 100644 index 000000000..6a4ec3d1e --- /dev/null +++ b/hwilib/wallet_policy.py @@ -0,0 +1,30 @@ +from typing import Any, Dict, List + +class WalletPolicy(object): + """Simple class to represent wallet policies.""" + + def __init__(self, name: str, descriptor_template: str, keys_info: List[str], extra: Dict[str, Any] = {}): + """TODO: document constructor arguments""" + self.name = name + self.descriptor_template = descriptor_template + self.keys_info = keys_info + self.extra = extra + + def to_descriptor(self) -> str: + """Converts a wallet policy into the descriptor (with the / syntax, if present).""" + + desc = self.descriptor_template + + # replace each "/**" with "/<0;1>/*" + desc = desc.replace("/**", "/<0;1>/*") + + # process all the @N expressions in decreasing order. This guarantees that string replacements + # works as expected (as any prefix expression is processed after). + for i in reversed(range(len(self.keys_info))): + desc = desc.replace(f"@{i}", self.keys_info[i]) + + # there should not be any remaining "@" expressions + if desc.find("@") != -1: + return Exception("Invalid descriptor template: contains invalid key index") + + return desc