diff --git a/src/embit/psbt.py b/src/embit/psbt.py index 54497c2..0d7b095 100644 --- a/src/embit/psbt.py +++ b/src/embit/psbt.py @@ -142,6 +142,7 @@ def __init__(self, unknown: dict = {}, vin=None, compress=CompressMode.KEEP_ALL) self.taproot_bip32_derivations = OrderedDict() self.taproot_internal_key = None self.taproot_merkle_root = None + self.taproot_key_sig = None self.taproot_sigs = OrderedDict() self.taproot_scripts = OrderedDict() @@ -187,6 +188,7 @@ def update(self, other): self.taproot_bip32_derivations.update(other.taproot_bip32_derivations) self.taproot_internal_key = other.taproot_internal_key self.taproot_merkle_root = other.taproot_merkle_root or self.taproot_merkle_root + self.taproot_key_sig = other.taproot_key_sig or self.taproot_key_sig self.taproot_sigs.update(other.taproot_sigs) self.taproot_scripts.update(other.taproot_scripts) self.final_scriptsig = other.final_scriptsig or self.final_scriptsig @@ -350,7 +352,15 @@ def read_value(self, stream, k): elif k == b"\x10": self.sequence = int.from_bytes(v, "little") - # TODO: 0x13 - tap key signature + # PSBT_IN_TAP_KEY_SIG + elif k[0] == 0x13: + # read the taproot key sig + if len(k) != 1: + raise PSBTError("Invalid taproot key signature key") + if self.taproot_key_sig is not None: + raise PSBTError("Duplicated taproot key signature") + self.taproot_key_sig = v + # PSBT_IN_TAP_SCRIPT_SIG elif k[0] == 0x14: if len(k) != 65: @@ -434,6 +444,11 @@ def write_to(self, stream, skip_separator=False, version=None, **kwargs) -> int: r += ser_string(stream, b"\x10") r += ser_string(stream, self.sequence.to_bytes(4, "little")) + # PSBT_IN_TAP_KEY_SIG + if self.taproot_key_sig is not None: + r += ser_string(stream, b"\x13") + r += ser_string(stream, self.taproot_key_sig) + # PSBT_IN_TAP_SCRIPT_SIG for pub, leaf in self.taproot_sigs: r += ser_string(stream, b"\x14" + pub.xonly() + leaf) @@ -881,11 +896,11 @@ def sign_input_with_tapkey( sighash=sighash, ) sig = pk.schnorr_sign(h) - wit = sig.serialize() + sigdata = sig.serialize() if sighash != SIGHASH.DEFAULT: - wit += bytes([sighash]) - # TODO: maybe better to put into internal key sig field - inp.final_scriptwitness = Witness([wit]) + sigdata += bytes([sighash]) + inp.taproot_key_sig = sigdata + inp.final_scriptwitness = Witness([sigdata]) # no need to sign anything else return 1 counter = 0