Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 215 additions & 0 deletions src/embit/bip352.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
"""
BIP-352: Silent Payments
see: https://github.com/bitcoin/bips/blob/master/bip-0352.mediawiki

TODO:
* Implement signing SP spends (once psbt format is settled).
"""
from embit import bech32, ec
from embit.util import secp256k1
from embit.hashes import tagged_hash
from typing import Tuple, List, Dict
from embit.util.key import SECP256K1_ORDER
from embit.transaction import COutPoint
from embit.util.secp256k1 import (
ec_pubkey_create,
ec_pubkey_serialize,
ec_pubkey_parse,
ec_pubkey_tweak_mul,
ec_pubkey_tweak_add,
ec_seckey_verify,
ec_privkey_negate,
)
from embit.script import p2tr
from binascii import hexlify, unhexlify


def generate_silent_payment_address(
scan_privkey: ec.PrivateKey,
spend_pubkey: ec.PublicKey,
label: int | str | bytes | None = None,
network: str = "main",
version: int = 0,
) -> str:
"""
Adapted from https://github.com/bitcoin/bips/blob/master/bip-0352/reference.py

Generates the recipient's reusable silent payment address for a given:
* scan private key
* spend public key
* optional label for labeled addresses
"""
scan_pubkey = scan_privkey.get_public_key()
if label is not None:
if isinstance(label, int):
label = label.to_bytes(4, "big")
elif isinstance(label, str):
label = label.encode()
tweak = tagged_hash("BIP0352/Label", scan_privkey.secret + label)
spend_pubkey = ec.PublicKey(
secp256k1.ec_pubkey_add(
secp256k1.ec_pubkey_parse(spend_pubkey.sec()), tweak
)
)

data = bech32.convertbits(scan_pubkey.sec() + spend_pubkey.sec(), 8, 5)
hrp = "sp" if network == "main" else "tsp"
return bech32.bech32_encode(bech32.Encoding.BECH32M, hrp, [version] + data)


# TODO: use the bech32 decode function once the flexible bech32 PR is in
def decode_silent_payment_address(address: str) -> Tuple[ec.PublicKey, ec.PublicKey]:
"""
Decode a silent payment address and return the scan and spend public keys.
Silent payment addresses can be longer than 90 characters, so we need custom decoding.
"""
if address.startswith("sp1"):
hrp = "sp"
elif address.startswith("tsp1"):
hrp = "tsp"
else:
raise ValueError("Invalid silent payment address: unknown HRP")

# custom bech32 to bypass the 90-character limit
if (any(ord(x) < 33 or ord(x) > 126 for x in address)) or (
address.lower() != address and address.upper() != address
):
raise ValueError("Invalid silent payment address: invalid characters")

address = address.lower()
pos = address.rfind("1")
if pos < 1 or pos + 7 > len(address):
raise ValueError("Invalid silent payment address: invalid format")

if not all(x in bech32.CHARSET for x in address[pos + 1 :]):
raise ValueError(
"Invalid silent payment address: invalid characters in data part"
)

hrpgot = address[:pos]
data = [bech32.CHARSET.find(x) for x in address[pos + 1 :]]

if hrpgot != hrp:
raise ValueError("Invalid silent payment address: HRP mismatch")

encoding = bech32.bech32_verify_checksum(hrpgot, data)
if encoding is None:
raise ValueError("Invalid silent payment address: checksum verification failed")

if encoding != bech32.Encoding.BECH32M:
raise ValueError("Invalid silent payment address: must use bech32m encoding")

data = data[:-6]

if data[0] != 0:
raise ValueError(
f"Invalid silent payment address: unsupported version {data[0]}"
)

decoded = bech32.convertbits(data[1:], 5, 8, False)
if decoded is None:
raise ValueError("Invalid silent payment address: conversion failed")

try:
B_scan = ec.PublicKey.parse(bytes(decoded[:33]))
B_spend = ec.PublicKey.parse(bytes(decoded[33:]))
except Exception as e:
raise ValueError(f"Invalid silent payment address: invalid public keys - {e}")

return B_scan, B_spend


def get_input_hash(outpoints: List["COutPoint"], sum_pubkey_bytes: bytes) -> bytes:
lowest_outpoint = sorted(outpoints, key=lambda o: o.serialize())[0]
preimage = lowest_outpoint.serialize() + sum_pubkey_bytes
return tagged_hash("BIP0352/Inputs", preimage)


def create_outputs(
input_privkeys: List[Tuple[bytes, bool]],
outpoints: List["COutPoint"],
recipients: List[str],
) -> Dict[str, List[str]]:
"""
Creates silent payment outputs for given recipients.

Args:
input_privkeys: List of (private_key_bytes, is_xonly) tuples
outpoints: List of transaction outpoints
recipients: List of silent payment addresses (strings) - duplicates are allowed

Returns:
Dictionary mapping each unique recipient address to list of output hex strings
"""
if not input_privkeys:
return {}

signing_keys = []
for sec, is_xonly in input_privkeys:
if not ec_seckey_verify(sec):
raise ValueError("Invalid private key")

if is_xonly:
pub = ec_pubkey_create(sec)
ser = ec_pubkey_serialize(pub)
if ser[0] == 0x03:
sec = ec_privkey_negate(sec)
signing_keys.append(int.from_bytes(sec, "big"))

a_sum = sum(signing_keys) % SECP256K1_ORDER
if a_sum == 0:
return {}

a_sum_bytes = a_sum.to_bytes(32, "big")
A = ec_pubkey_create(a_sum_bytes)

input_hash = get_input_hash(outpoints, ec_pubkey_serialize(A))

from collections import Counter, defaultdict

recipient_counts = Counter(recipients)

groups: Dict[ec.PublicKey, List[Tuple[ec.PublicKey, str, int]]] = defaultdict(list)
for addr, count in recipient_counts.items():
B_scan, B_spend = decode_silent_payment_address(addr)
groups[B_scan].append((B_spend, addr, count))

result: Dict[str, List[str]] = {addr: [] for addr in recipient_counts.keys()}
scalar = (int.from_bytes(input_hash, "big") * a_sum) % SECP256K1_ORDER
scalar_bytes = scalar.to_bytes(32, "big")

for B_scan, B_spend_list in groups.items():
ecdh_point = ec_pubkey_parse(B_scan.sec())
ec_pubkey_tweak_mul(ecdh_point, scalar_bytes)
xonly_shared_secret = ec_pubkey_serialize(ecdh_point)

k = 0
for B_spend, addr, count in B_spend_list:
for _ in range(count):
t_k = tagged_hash(
"BIP0352/SharedSecret",
xonly_shared_secret + k.to_bytes(4, "big"),
)

P_k = ec_pubkey_parse(B_spend.sec())
ec_pubkey_tweak_add(P_k, t_k)

xonly = ec_pubkey_serialize(P_k)[1:33]
result[addr].append(hexlify(xonly).decode())
k += 1

return result


def generate_sp_destination_address(
input_privkeys: List[Tuple[bytes, bool]],
outpoints: List["COutPoint"],
recipient_sp_address: str,
) -> str:
outputs = create_outputs(input_privkeys, outpoints, [recipient_sp_address])

dest_addr = []
for output in outputs:
pubkey = ec.PublicKey.parse(b"\x02" + unhexlify(output))
dest_addr.append(p2tr(pubkey).address())
return dest_addr
104 changes: 104 additions & 0 deletions src/embit/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from . import hashes
from . import compact
from .base import EmbitBase, EmbitError
from .util.key import ECPubKey, NUMS_H

SIGHASH_ALL = 1

Expand Down Expand Up @@ -60,6 +61,12 @@ def script_type(self):
# unknown type
return None

def is_p2tr(self):
"""
Check if the script is a Pay-to-Taproot script.
"""
return self.script_type() == "p2tr"

def write_to(self, stream):
res = stream.write(compact.to_bytes(len(self.data)))
res += stream.write(self.data)
Expand Down Expand Up @@ -210,3 +217,100 @@ def script_sig_p2sh(redeem_script):

def witness_p2wpkh(signature, pubkey, sighash=SIGHASH_ALL):
return Witness([signature.serialize() + bytes([sighash]), pubkey.sec()])


# New helper: extract/validate input pubkey for various script types


def get_input_pubkey(
prevout_script, script_sig: bytes | str | None = None, witness=None
) -> ECPubKey:
"""Extract and validate a public key for an input based on its prevout script type.

- prevout_script: Script or raw bytes of the prevout scriptPubKey
- script_sig: raw bytes or hex string of the scriptSig (if legacy)
- witness: Witness object or a list of bytes for segwit/taproot

Returns ECPubKey() with .valid=False if no suitable compressed pubkey can be determined.
"""
if isinstance(prevout_script, bytes):
spk = Script(prevout_script)
else:
spk = prevout_script

if isinstance(script_sig, str):
try:
ss = bytes.fromhex(script_sig)
except Exception:
ss = b""
elif isinstance(script_sig, bytes):
ss = script_sig
else:
ss = b""

if isinstance(witness, Witness):
wstack = witness.items
elif isinstance(witness, list):
wstack = witness
else:
wstack = []

script_type = spk.script_type()

if script_type == "p2pkh":
spk_hash = spk.data[3:23]
for i in range(len(ss), 32, -1):
if i >= 33:
pubkey_bytes = ss[i - 33 : i]
if len(pubkey_bytes) == 33 and pubkey_bytes[0] in (0x02, 0x03):
if hashes.hash160(pubkey_bytes) == spk_hash:
pubkey = ECPubKey().set(pubkey_bytes)
if pubkey.valid and pubkey.is_compressed:
return pubkey
return ECPubKey()

if script_type == "p2sh":
if len(ss) > 1 and wstack:
pubkey_bytes = wstack[-1]
pubkey = ECPubKey().set(pubkey_bytes)
if pubkey.valid and pubkey.is_compressed:
return pubkey
return ECPubKey()

if script_type == "p2wpkh":
if wstack:
pubkey_bytes = wstack[-1]
pubkey = ECPubKey().set(pubkey_bytes)
if pubkey.valid and pubkey.is_compressed:
return pubkey
return ECPubKey()

if script_type == "p2tr":
if wstack:
# strip annex if present (last element starting with 0x50)
if len(wstack) > 1 and wstack[-1][:1] == b"\x50":
wstack = wstack[:-1]
# Script-path spend: if control block signals NUMS, skip
if len(wstack) > 1:
control_block = wstack[-1]
if len(control_block) >= 33 and control_block[1:33] == NUMS_H.to_bytes(
32, "big"
):
return ECPubKey()
# Key-path spend: reconstruct even-y compressed SEC from x-only
if len(spk.data) >= 34:
xonly = spk.data[2:34]
pubkey_bytes = b"\x02" + xonly
pubkey = ECPubKey().set(pubkey_bytes)
if pubkey.valid:
return pubkey
else:
if len(spk.data) >= 34:
xonly = spk.data[2:34]
pubkey_bytes = b"\x02" + xonly
pubkey = ECPubKey().set(pubkey_bytes)
if pubkey.valid:
return pubkey
return ECPubKey()

return ECPubKey()
36 changes: 36 additions & 0 deletions src/embit/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from .base import EmbitBase, EmbitError
from .script import Script, Witness
from .misc import const
from binascii import unhexlify
from typing import NamedTuple


class TransactionError(EmbitError):
Expand Down Expand Up @@ -398,3 +400,37 @@ def read_from(cls, stream):
value = int.from_bytes(stream.read(8), "little")
script_pubkey = Script.read_from(stream)
return cls(value, script_pubkey)


class COutPoint(NamedTuple):
txid: bytes # endianness same as hex string displayed; reverse of tx serialization order
out_idx: int

@classmethod
def from_str(cls, s: str) -> "COutPoint":
hash_str, idx_str = s.split(":")
assert len(hash_str) == 64, f"{hash_str} should be a sha256 hash"
return COutPoint(txid=unhexlify(hash_str), out_idx=int(idx_str))

def __str__(self) -> str:
return f"""COutPoint("{self.to_str()}")"""

def __repr__(self):
return f"<{str(self)}>"

def to_str(self) -> str:
return f"{self.txid.hex()}:{self.out_idx}"

def to_json(self):
return [self.txid.hex(), self.out_idx]

def serialize(self) -> bytes:
return self.txid[::-1] + int.to_bytes(
self.out_idx, length=4, byteorder="little", signed=False
)

def is_coinbase(self) -> bool:
return self.txid == bytes(32)

def short_name(self):
return f"{self.txid.hex()[0:10]}:{self.out_idx}"
Loading