From 2be3999c09f74509532ca6d677aba08df9fc1451 Mon Sep 17 00:00:00 2001 From: odudex Date: Tue, 4 Jun 2024 13:21:50 -0300 Subject: [PATCH 1/5] replace assertions by conditional error raising --- src/embit/bip85.py | 12 ++++-- src/embit/descriptor/arguments.py | 6 ++- src/embit/ec.py | 6 ++- src/embit/liquid/pset.py | 9 +++-- src/embit/liquid/psetview.py | 3 +- src/embit/liquid/transaction.py | 27 +++++++++----- src/embit/misc.py | 3 +- src/embit/psbtview.py | 6 ++- src/embit/util/ctypes_secp256k1.py | 21 +++++++---- src/embit/util/key.py | 60 ++++++++++++++++++++---------- src/embit/util/py_ripemd160.py | 2 +- src/embit/util/py_secp256k1.py | 15 +++++--- 12 files changed, 113 insertions(+), 57 deletions(-) diff --git a/src/embit/bip85.py b/src/embit/bip85.py index 5867abd..6b74dc0 100644 --- a/src/embit/bip85.py +++ b/src/embit/bip85.py @@ -17,7 +17,8 @@ def derive_entropy(root, app_index, path): """ Derive app-specific bip85 entropy using path m/83696968'/app_index'/...path' """ - assert max(path) < HARDENED_INDEX + if max(path) >= HARDENED_INDEX: + raise ValueError("Path elements must be less than 2^31") derivation = [HARDENED_INDEX + 83696968, HARDENED_INDEX + app_index] + [ p + HARDENED_INDEX for p in path ] @@ -27,7 +28,8 @@ def derive_entropy(root, app_index, path): def derive_mnemonic(root, num_words=12, index=0, language=LANGUAGES.ENGLISH): """Derive a new mnemonic with num_words using language (code, wordlist)""" - assert num_words in [12, 18, 24] + if num_words not in [12, 18, 24]: + raise ValueError("Number of words must be 12, 18 or 24") langcode, wordlist = language path = [langcode, num_words, index] entropy = derive_entropy(root, 39, path) @@ -49,7 +51,9 @@ def derive_xprv(root, index=0): def derive_hex(root, num_bytes=32, index=0): """Derive raw entropy from 16 to 64 bytes long""" - assert num_bytes <= 64 - assert num_bytes >= 16 + if num_bytes > 64: + raise ValueError("Number of bytes must be less than 64") + if num_bytes < 16: + raise ValueError("Number of bytes must be at least 16") entropy = derive_entropy(root, 128169, [num_bytes, index]) return entropy[:num_bytes] diff --git a/src/embit/descriptor/arguments.py b/src/embit/descriptor/arguments.py index 3f92500..0c8d4d5 100644 --- a/src/embit/descriptor/arguments.py +++ b/src/embit/descriptor/arguments.py @@ -15,7 +15,8 @@ def __init__(self, fingerprint: bytes, derivation: list): def from_string(cls, s: str): arr = s.split("/") mfp = unhexlify(arr[0]) - assert len(mfp) == 4 + if len(mfp) != 4: + raise ArgumentError("Invalid fingerprint length") arr[0] = "m" path = "/".join(arr) derivation = bip32.parse_path(path) @@ -315,7 +316,8 @@ def xonly(self): return self.key.xonly() def taproot_tweak(self, h=b""): - assert self.taproot + if not self.taproot: + raise ArgumentError("Key is not taproot") return self.key.taproot_tweak(h) def serialize(self): diff --git a/src/embit/ec.py b/src/embit/ec.py index a93fc71..b36d861 100644 --- a/src/embit/ec.py +++ b/src/embit/ec.py @@ -26,7 +26,8 @@ def read_from(cls, stream): class SchnorrSig(EmbitBase): def __init__(self, sig): - assert len(sig) == 64 + if len(sig) != 64: + raise ECError("Invalid schnorr signature") self._sig = sig def write_to(self, stream) -> int: @@ -93,7 +94,8 @@ def _xonly(self): @classmethod def from_xonly(cls, data: bytes): - assert len(data) == 32 + if len(data) != 32: + raise ECError("Invalid xonly pubkey") return cls.parse(b"\x02" + data) def schnorr_verify(self, sig, msg_hash) -> bool: diff --git a/src/embit/liquid/pset.py b/src/embit/liquid/pset.py index 3a2c8b8..290c59a 100644 --- a/src/embit/liquid/pset.py +++ b/src/embit/liquid/pset.py @@ -96,9 +96,11 @@ def unblind(self, blinding_key): return # verify gen = secp256k1.generator_generate_blinded(asset, in_abf) - assert gen == secp256k1.generator_parse(self.utxo.asset) + if gen != secp256k1.generator_parse(self.utxo.asset): + raise PSBTError("Invalid asset commitment") cmt = secp256k1.pedersen_commit(vbf, value, gen) - assert cmt == secp256k1.pedersen_commitment_parse(self.utxo.value) + if cmt != secp256k1.pedersen_commitment_parse(self.utxo.value): + raise PSBTError("Invalid value commitment") self.asset = asset self.value = value @@ -506,7 +508,8 @@ def unblind(self, blinding_key): inp.unblind(blinding_key) def txseed(self, seed: bytes): - assert len(seed) == 32 + if len(seed) != 32: + raise PSBTError("Seed should be 32 bytes") # get unique seed for this tx: # we use seed + txid:vout + scriptpubkey as unique data for tagged hash data = b"".join( diff --git a/src/embit/liquid/psetview.py b/src/embit/liquid/psetview.py index dbedbbd..70dda48 100644 --- a/src/embit/liquid/psetview.py +++ b/src/embit/liquid/psetview.py @@ -5,7 +5,8 @@ def skip_commitment(stream): c = stream.read(1) - assert len(c) == 1 + if len(c) != 1: + raise PSBTError("Unexpected end of stream") if c == b"\x00": # None return 1 if c == b"\x01": # unconfidential diff --git a/src/embit/liquid/transaction.py b/src/embit/liquid/transaction.py index b64dd02..ca7af0a 100644 --- a/src/embit/liquid/transaction.py +++ b/src/embit/liquid/transaction.py @@ -40,16 +40,19 @@ class LTransactionError(TransactionError): def read_commitment(stream): c = stream.read(1) - assert len(c) == 1 + if len(c) != 1: + raise TransactionError("Invalid commitment") if c == b"\x00": # None return None if c == b"\x01": # unconfidential r = stream.read(8) - assert len(r) == 8 + if len(r) != 8: + raise TransactionError("Invalid commitment") return int.from_bytes(r, "big") # confidential r = stream.read(32) - assert len(r) == 32 + if len(r) != 32: + raise TransactionError("Invalid commitment") return c + r @@ -71,10 +74,14 @@ def unblind( message_length=64, ) -> tuple: """Unblinds a range proof and returns value, asset, value blinding factor, asset blinding factor, extra data, min and max values""" - assert len(pubkey) in [33, 65] - assert len(blinding_key) == 32 - assert len(value_commitment) == 33 - assert len(asset_commitment) == 33 + if len(pubkey) not in [33, 65]: + raise TransactionError("Invalid pubkey length") + if len(blinding_key) != 32: + raise TransactionError("Invalid blinding key length") + if len(value_commitment) != 33: + raise TransactionError("Invalid value commitment length") + if len(asset_commitment) != 33: + raise TransactionError("Invalid asset commitment length") pub = secp256k1.ec_pubkey_parse(pubkey) secp256k1.ec_pubkey_tweak_mul(pub, blinding_key) sec = secp256k1.ec_pubkey_serialize(pub) @@ -397,9 +404,11 @@ def __init__(self, nonce, entropy, amount_commitment, token_commitment=None): @classmethod def read_from(cls, stream): nonce = stream.read(32) - assert len(nonce) == 32 + if len(nonce) != 32: + raise TransactionError("Invalid nonce") entropy = stream.read(32) - assert len(entropy) == 32 + if len(entropy) != 32: + raise TransactionError("Invalid entropy") amount_commitment = read_commitment(stream) token_commitment = read_commitment(stream) return cls(nonce, entropy, amount_commitment, token_commitment) diff --git a/src/embit/misc.py b/src/embit/misc.py index fc2c804..97669ad 100644 --- a/src/embit/misc.py +++ b/src/embit/misc.py @@ -35,7 +35,8 @@ def secure_randint(vmin: int, vmax: int) -> int: """ import math - assert vmax > vmin + if vmax <= vmin: + raise ValueError("vmax must be greater than vmin") delta = vmax - vmin nbits = math.ceil(math.log2(delta + 1)) randn = getrandbits(nbits) diff --git a/src/embit/psbtview.py b/src/embit/psbtview.py index 8012654..1d65b1d 100644 --- a/src/embit/psbtview.py +++ b/src/embit/psbtview.py @@ -239,8 +239,10 @@ def view(cls, stream, offset=None, compress=CompressMode.KEEP_ALL): num_outputs = compact.from_bytes(value) elif key == b"\x00": # we found global transaction - assert version != 2 - assert (num_inputs is None) and (num_outputs is None) + if version == 2: + raise PSBTError("Global transaction with version 2 PSBT") + if (num_inputs is not None) or (num_outputs is not None): + raise PSBTError("Invalid global transaction") tx_len = compact.read_from(stream) cur += len(compact.to_bytes(tx_len)) tx_offset = cur diff --git a/src/embit/util/ctypes_secp256k1.py b/src/embit/util/ctypes_secp256k1.py index 232abda..8068355 100644 --- a/src/embit/util/ctypes_secp256k1.py +++ b/src/embit/util/ctypes_secp256k1.py @@ -761,16 +761,20 @@ def xonly_pubkey_from_pubkey(pubkey, context=_secp.ctx): @locked def schnorrsig_verify(sig, msg, pubkey, context=_secp.ctx): - assert len(sig) == 64 - assert len(msg) == 32 - assert len(pubkey) == 64 + if len(sig) != 64: + raise ValueError("Signature should be 64 bytes long") + if len(msg) != 32: + raise ValueError("Message should be 32 bytes long") + if len(pubkey) != 64: + raise ValueError("Public key should be 64 bytes long") res = _secp.secp256k1_schnorrsig_verify(context, sig, msg, pubkey) return bool(res) @locked def keypair_create(secret, context=_secp.ctx): - assert len(secret) == 32 + if len(secret) != 32: + raise ValueError("Secret key should be 32 bytes long") keypair = bytes(96) r = _secp.secp256k1_keypair_create(context, keypair, secret) if r == 0: @@ -782,11 +786,13 @@ def keypair_create(secret, context=_secp.ctx): def schnorrsig_sign( msg, keypair, nonce_function=None, extra_data=None, context=_secp.ctx ): - assert len(msg) == 32 + if len(msg) != 32: + raise ValueError("Message should be 32 bytes long") if len(keypair) == 32: keypair = keypair_create(keypair, context=context) with _lock: - assert len(keypair) == 96 + if len(keypair) != 96: + raise ValueError("Keypair should be 96 bytes long") sig = bytes(64) r = _secp.secp256k1_schnorrsig_sign( context, sig, msg, keypair, nonce_function, extra_data @@ -916,7 +922,8 @@ def pedersen_blind_generator_blind_sum( if res == 0: raise ValueError("Failed to get the last blinding factor.") res = (c_char * 32).from_address(address).raw - assert len(res) == 32 + if len(res) != 32: + raise ValueError("Blinding factor should be 32 bytes long") return res diff --git a/src/embit/util/key.py b/src/embit/util/key.py index 13b01d9..f573ec5 100644 --- a/src/embit/util/key.py +++ b/src/embit/util/key.py @@ -43,7 +43,8 @@ def jacobi_symbol(n, k): For our application k is always prime, so this is the same as the Legendre symbol. """ - assert k > 0 and k & 1, "jacobi symbol is only defined for positive odd k" + if k <= 0 or k % 2 == 0: + raise ValueError("jacobi symbol is only defined for positive odd k") n %= k t = 0 while n != 0: @@ -165,7 +166,8 @@ def add_mixed(self, p1, p2): """ x1, y1, z1 = p1 x2, y2, z2 = p2 - assert z2 == 1 + if z2 != 1: + raise ValueError("p2 must be an affine point") # Adding to the point at infinity is a no-op if z1 == 0: return p2 @@ -299,7 +301,8 @@ def is_valid(self): return self.valid def get_bytes(self): - assert self.valid + if not self.valid: + raise ValueError("Invalid public key") p = SECP256K1.affine(self.p) if p is None: return None @@ -313,7 +316,8 @@ def verify_ecdsa(self, sig, msg, low_s=True): See https://en.wikipedia.org/wiki/Elliptic_Curve_Digital_Signature_Algorithm for the ECDSA verifier algorithm""" - assert self.valid + if not self.valid: + raise ValueError("Invalid public key") # Extract r and s from the DER formatted signature. Return false for # any DER encoding errors. @@ -378,7 +382,8 @@ def __init__(self): def set(self, secret, compressed): """Construct a private key object with given 32-byte secret and compressed flag.""" - assert len(secret) == 32 + if len(secret) != 32: + raise ValueError("Invalid secret key length") secret = int.from_bytes(secret, "big") self.valid = secret > 0 and secret < SECP256K1_ORDER if self.valid: @@ -391,7 +396,8 @@ def generate(self, compressed=True): def get_bytes(self): """Retrieve the 32-byte representation of this key.""" - assert self.valid + if not self.valid: + raise ValueError("Invalid private key") return self.secret.to_bytes(32, "big") @property @@ -404,7 +410,8 @@ def is_compressed(self): def get_pubkey(self): """Compute an ECPubKey object for this secret key.""" - assert self.valid + if not self.valid: + raise ValueError("Invalid private key") ret = ECPubKey() p = SECP256K1.mul([(SECP256K1_G, self.secret)]) ret.p = p @@ -417,7 +424,8 @@ def sign_ecdsa(self, msg, nonce_function=None, extra_data=None, low_s=True): See https://en.wikipedia.org/wiki/Elliptic_Curve_Digital_Signature_Algorithm for the ECDSA signer algorithm.""" - assert self.valid + if not self.valid: + raise ValueError("Invalid private key") z = int.from_bytes(msg, "big") if nonce_function is None: nonce_function = deterministic_k @@ -470,7 +478,8 @@ def compute_xonly_pubkey(key): This also returns whether the resulting public key was negated. """ - assert len(key) == 32 + if len(key) != 32: + raise ValueError("Invalid private key length") x = int.from_bytes(key, "big") if x == 0 or x >= SECP256K1_ORDER: return (None, None) @@ -481,8 +490,10 @@ def compute_xonly_pubkey(key): def tweak_add_privkey(key, tweak): """Tweak a private key (after negating it if needed).""" - assert len(key) == 32 - assert len(tweak) == 32 + if len(key) != 32: + raise ValueError("Invalid private key length") + if len(tweak) != 32: + raise ValueError("Invalid tweak length") x = int.from_bytes(key, "big") if x == 0 or x >= SECP256K1_ORDER: @@ -501,8 +512,10 @@ def tweak_add_privkey(key, tweak): def tweak_add_pubkey(key, tweak): """Tweak a public key and return whether the result had to be negated.""" - assert len(key) == 32 - assert len(tweak) == 32 + if len(key) != 32: + raise ValueError("Invalid public key length") + if len(tweak) != 32: + raise ValueError("Invalid tweak length") x_coord = int.from_bytes(key, "big") if x_coord >= SECP256K1_FIELD_SIZE: @@ -525,9 +538,12 @@ def verify_schnorr(key, sig, msg): - sig is a 64-byte Schnorr signature - msg is a 32-byte message """ - assert len(key) == 32 - assert len(msg) == 32 - assert len(sig) == 64 + if len(key) != 32: + raise ValueError("Invalid public key length") + if len(msg) != 32: + raise ValueError("Invalid message length") + if len(sig) != 64: + raise ValueError("Invalid signature length") x_coord = int.from_bytes(key, "big") if x_coord == 0 or x_coord >= SECP256K1_FIELD_SIZE: @@ -556,10 +572,13 @@ def verify_schnorr(key, sig, msg): def sign_schnorr(key, msg, aux=None, flip_p=False, flip_r=False): """Create a Schnorr signature (see BIP 340).""" - assert len(key) == 32 - assert len(msg) == 32 + if len(key) != 32: + raise ValueError("Invalid private key length") + if len(msg) != 32: + raise ValueError("Invalid message length") if aux is not None: - assert len(aux) == 32 + if len(aux) != 32: + raise ValueError("Invalid aux length") sec = int.from_bytes(key, "big") if sec == 0 or sec >= SECP256K1_ORDER: @@ -579,7 +598,8 @@ def sign_schnorr(key, msg, aux=None, flip_p=False, flip_r=False): ) % SECP256K1_ORDER ) - assert kp != 0 + if kp == 0: + raise ValueError("k is zero") R = SECP256K1.affine(SECP256K1.mul([(SECP256K1_G, kp)])) k = kp if SECP256K1.has_even_y(R) != flip_r else SECP256K1_ORDER - kp e = ( diff --git a/src/embit/util/py_ripemd160.py b/src/embit/util/py_ripemd160.py index d42c4a3..d995b26 100644 --- a/src/embit/util/py_ripemd160.py +++ b/src/embit/util/py_ripemd160.py @@ -359,7 +359,7 @@ def fi(x, y, z, i): elif i == 4: return x ^ (y | ~z) else: - assert False + raise ValueError("Invalid function index") def rol(x, i): diff --git a/src/embit/util/py_secp256k1.py b/src/embit/util/py_secp256k1.py index 8514086..0f1f8e1 100644 --- a/src/embit/util/py_secp256k1.py +++ b/src/embit/util/py_secp256k1.py @@ -283,9 +283,12 @@ def xonly_pubkey_from_pubkey(pubkey, context=None): def schnorrsig_verify(sig, msg, pubkey, context=None): - assert len(sig) == 64 - assert len(msg) == 32 - assert len(pubkey) == 64 + if len(sig) != 64: + raise ValueError("Signature should be 64 bytes long") + if len(msg) != 32: + raise ValueError("Message should be 32 bytes long") + if len(pubkey) != 64: + raise ValueError("Public key should be 64 bytes long") sec = ec_pubkey_serialize(pubkey) return _key.verify_schnorr(sec[1:33], sig, msg) @@ -298,10 +301,12 @@ def keypair_create(secret, context=None): def schnorrsig_sign(msg, keypair, nonce_function=None, extra_data=None, context=None): - assert len(msg) == 32 + if len(msg) != 32: + raise ValueError("Message should be 32 bytes long") if len(keypair) == 32: keypair = keypair_create(keypair, context=context) - assert len(keypair) == 96 + if len(keypair) != 96: + raise ValueError("Keypair should be 96 bytes long") return _key.sign_schnorr(keypair[:32], msg, extra_data) From 84ecb0ca8406bf6a5ad654244beb6fd1ad7b3900 Mon Sep 17 00:00:00 2001 From: odudex Date: Tue, 29 Jul 2025 17:00:04 -0300 Subject: [PATCH 2/5] ValueError message adjust --- src/embit/bip85.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/embit/bip85.py b/src/embit/bip85.py index 6b74dc0..97013e5 100644 --- a/src/embit/bip85.py +++ b/src/embit/bip85.py @@ -52,7 +52,7 @@ def derive_xprv(root, index=0): def derive_hex(root, num_bytes=32, index=0): """Derive raw entropy from 16 to 64 bytes long""" if num_bytes > 64: - raise ValueError("Number of bytes must be less than 64") + raise ValueError("Number of bytes must not exceed 64") if num_bytes < 16: raise ValueError("Number of bytes must be at least 16") entropy = derive_entropy(root, 128169, [num_bytes, index]) From 89a941b092f3ed8685176323cfbd7b3a9c75b966 Mon Sep 17 00:00:00 2001 From: qlrd <106913782+qlrd@users.noreply.github.com> Date: Tue, 29 Jul 2025 20:04:42 +0000 Subject: [PATCH 3/5] test: add cases for `src/embit/bip85.py` (#1) In order to test the replace assertions with conditional error raising, specifically, in bip85 code, was added a test for `derive_entropy` method as well the failure cases described in `derive_mnemonic` and `derive_hex`. --- tests/tests/test_bip85.py | 44 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/tests/test_bip85.py b/tests/tests/test_bip85.py index e5ec9c4..a75fee3 100644 --- a/tests/tests/test_bip85.py +++ b/tests/tests/test_bip85.py @@ -48,12 +48,44 @@ class Bip85Test(TestCase): + + def test_derive_entropy(self): + for app_index, path, expected in [ + (39, [0, 12, 0], unhexlify("6250b68daf746d12a24d58b4787a714bf1b58d69e4c2a466276fb16fe93dc52b6fac6b756894072241447cad56f6405ee326dbb473d2f5e943543590082927c0")), + (2, [0], unhexlify("7040bb53104f27367f317558e78a994ada7296c6fde36a364e5baf206e502bb1f988080b7dd814e7ae7d6d83edbb6689886a560e165f4a740877cdf3beecacf8")), + (32, [0], unhexlify("52405cd0dd21c5be78314a7c1a3c65ffd8d896536cc7dee3157db5824f0c92e2ead0b33988a616cf6a497f1c169d9e92562604e38305ccd3fc96f2252c177682")), + ]: + result = bip85.derive_entropy(ROOT, app_index, path) + self.assertEqual(result, expected) + + def test_derive_entropy_fail_path_ge_hardened_index(self): + with self.assertRaises(ValueError) as exc: + bip85.derive_entropy(ROOT, 39, [bip32.HARDENED_INDEX + 1]) + self.assertEqual(str(exc.exception), "Path elements must be less than 2^31") + def test_bip39(self): for num_words, index, lang, expected in VECTORS_BIP39: self.assertEqual( bip85.derive_mnemonic(ROOT, num_words, index, language=lang), expected ) + def test_bip39_fail_num_words(self): + cases = [ + (11, 0, bip85.LANGUAGES.ENGLISH), + (13, 0, bip85.LANGUAGES.ENGLISH), + (15, 0, bip85.LANGUAGES.ENGLISH), + (17, 0, bip85.LANGUAGES.ENGLISH), + (19, 0, bip85.LANGUAGES.ENGLISH), + (21, 0, bip85.LANGUAGES.ENGLISH), + (23, 0, bip85.LANGUAGES.ENGLISH), + (25, 0, bip85.LANGUAGES.ENGLISH), + ] + + for num_words, index, lang in cases: + with self.assertRaises(ValueError) as exc: + bip85.derive_mnemonic(ROOT, num_words, index, language=lang) + self.assertEqual(str(exc.exception), "Number of words must be 12, 18 or 24") + def test_wif(self): for idx, expected in VECTORS_WIF: self.assertEqual(bip85.derive_wif(ROOT, idx).wif(), expected) @@ -67,3 +99,15 @@ def test_hex(self): self.assertEqual( bip85.derive_hex(ROOT, num_bytes, idx), unhexlify(expected) ) + + def test_hex_fail_num_bytes_ge_64(self): + for num_bytes in [65, 100, 1000, 10000]: + with self.assertRaises(ValueError) as exc: + bip85.derive_hex(ROOT, num_bytes, 1) + self.assertEqual(str(exc.exception), "Number of bytes must be less than 64") + + def test_hex_fail_num_bytes_le_16(self): + for num_bytes in [15, 14, 10, 0]: + with self.assertRaises(ValueError) as exc: + bip85.derive_hex(ROOT, num_bytes, 2) + self.assertEqual(str(exc.exception), "Number of bytes must be at least 16") From 1f1146a82483de85557740232892156cafae865f Mon Sep 17 00:00:00 2001 From: odudex Date: Wed, 30 Jul 2025 10:42:16 -0300 Subject: [PATCH 4/5] adjust test error message --- tests/tests/test_bip85.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests/test_bip85.py b/tests/tests/test_bip85.py index a75fee3..7b6280b 100644 --- a/tests/tests/test_bip85.py +++ b/tests/tests/test_bip85.py @@ -104,7 +104,7 @@ def test_hex_fail_num_bytes_ge_64(self): for num_bytes in [65, 100, 1000, 10000]: with self.assertRaises(ValueError) as exc: bip85.derive_hex(ROOT, num_bytes, 1) - self.assertEqual(str(exc.exception), "Number of bytes must be less than 64") + self.assertEqual(str(exc.exception), "Number of bytes must not exceed 64") def test_hex_fail_num_bytes_le_16(self): for num_bytes in [15, 14, 10, 0]: From 1ec692b3dabc5ff45691b587e3fa10183ccf01d2 Mon Sep 17 00:00:00 2001 From: odudex Date: Tue, 4 Jun 2024 12:15:26 -0300 Subject: [PATCH 5/5] refactor descriptor parsing --- src/embit/descriptor/descriptor.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/embit/descriptor/descriptor.py b/src/embit/descriptor/descriptor.py index 9f585dc..28ced7a 100644 --- a/src/embit/descriptor/descriptor.py +++ b/src/embit/descriptor/descriptor.py @@ -294,7 +294,7 @@ def from_string(cls, desc): @classmethod def read_from(cls, s): # starts with sh(wsh()), sh() or wsh() - start = s.read(7) + start = s.read(8) sh = False wsh = False wpkh = False @@ -303,30 +303,30 @@ def read_from(cls, s): taptree = TapTree() if start.startswith(b"tr("): taproot = True - s.seek(-4, 1) + s.seek(-5, 1) elif start.startswith(b"sh(wsh("): sh = True wsh = True + s.seek(-1, 1) elif start.startswith(b"wsh("): sh = False wsh = True - s.seek(-3, 1) - elif start.startswith(b"sh(wpkh"): + s.seek(-4, 1) + elif start.startswith(b"sh(wpkh("): is_miniscript = False sh = True wpkh = True - assert s.read(1) == b"(" elif start.startswith(b"wpkh("): is_miniscript = False wpkh = True - s.seek(-2, 1) + s.seek(-3, 1) elif start.startswith(b"pkh("): is_miniscript = False - s.seek(-3, 1) + s.seek(-4, 1) elif start.startswith(b"sh("): sh = True wsh = False - s.seek(-4, 1) + s.seek(-5, 1) else: raise ValueError("Invalid descriptor (starts with '%s')" % start.decode()) # taproot always has a key, and may have taptree miniscript