From 818cfc1f08a680ba2b540c0cef28563cd9d942cc Mon Sep 17 00:00:00 2001 From: antazoey Date: Fri, 20 Feb 2026 09:28:12 -0600 Subject: [PATCH 1/3] more fixes --- tests/model/test_asset_identifier.py | 8 + tests/model/test_settlement.py | 13 +- tplus/client/clearingengine/assetregistry.py | 20 ++ tplus/client/clearingengine/decimal.py | 21 +- tplus/client/clearingengine/vault.py | 7 + tplus/evm/contracts.py | 108 +++++++-- tplus/evm/managers/chaindata.py | 17 +- tplus/evm/managers/credential_manager.py | 4 +- tplus/evm/managers/deposit.py | 16 +- tplus/evm/managers/registry.py | 41 ++++ tplus/evm/managers/settle.py | 238 ++++++++++++------- tplus/model/approval.py | 12 + tplus/model/asset_identifier.py | 12 + tplus/model/chain_address.py | 10 +- tplus/model/settlement.py | 32 ++- tplus/utils/amount.py | 2 +- tplus/utils/user/model.py | 12 +- 17 files changed, 434 insertions(+), 139 deletions(-) create mode 100644 tplus/model/approval.py diff --git a/tests/model/test_asset_identifier.py b/tests/model/test_asset_identifier.py index 8677d06..16aeef7 100644 --- a/tests/model/test_asset_identifier.py +++ b/tests/model/test_asset_identifier.py @@ -1,6 +1,7 @@ import pytest from tplus.model.asset_identifier import AssetIdentifier +from tplus.model.chain_address import ChainAddress class TestAssetIdentifier: @@ -130,3 +131,10 @@ def test_chain_id_too_small(self): raw_str = "62622E77D1349Face943C6e7D5c01C61465FE1dc@a4b1" with pytest.raises(ValueError): _ = AssetIdentifier(raw_str) + + def test_validate_from_chain_address(self): + address = ChainAddress.from_str( + "62622E77D1349Face943C6e7D5c01C61465FE1dc@000000000000aa36a7" + ) + asset = AssetIdentifier.model_validate(address) + assert asset == address diff --git a/tests/model/test_settlement.py b/tests/model/test_settlement.py index 4272725..a7532ac 100644 --- a/tests/model/test_settlement.py +++ b/tests/model/test_settlement.py @@ -15,7 +15,15 @@ class TestInnerSettlementRequest: def test_from_raw(self, user): request = InnerSettlementRequest.from_raw( - ASSET_IN, 100, 6, ASSET_OUT, 100, 18, user.public_key, CHAIN_ID + ASSET_IN, + 100, + 6, + ASSET_OUT, + 100, + 18, + user.public_key, + CHAIN_ID, + 0, ) assert ( request.asset_in.root @@ -35,6 +43,7 @@ class TestTxSettlementRequest: def settlement(self, user): return { "tplus_user": user.public_key, + "sub_account_index": 0, "settler": user.public_key, **get_base_settlement_data(), "chain_id": CHAIN_ID, @@ -46,7 +55,7 @@ def test_signing_payload(self, settlement, user): """ settlement = TxSettlementRequest(inner=settlement, signature=[]) actual = settlement.signing_payload() - expected = f'{{"tplus_user":"{user.public_key}","settler":"{user.public_key}","asset_in":"62622e77d1349face943c6e7d5c01c61465fe1dc000000000000000000000000@000000000000aa36a7","amount_in":"9f4cfc56cd29b000","asset_out":"58372ab62269a52fa636ad7f200d93999595dcaf000000000000000000000000@000000000000aa36a7","amount_out":"8e1bc9bf04000","chain_id":"000000000000aa36a7"}}' + expected = f'{{"tplus_user":"{user.public_key}","sub_account_index":0,"settler":"{user.public_key}","asset_in":"62622e77d1349face943c6e7d5c01c61465fe1dc000000000000000000000000@000000000000aa36a7","amount_in":"9f4cfc56cd29b000","asset_out":"58372ab62269a52fa636ad7f200d93999595dcaf000000000000000000000000@000000000000aa36a7","amount_out":"8e1bc9bf04000","chain_id":"000000000000aa36a7"}}' assert actual == expected # Show it is the same as the inner version. diff --git a/tplus/client/clearingengine/assetregistry.py b/tplus/client/clearingengine/assetregistry.py index 2c12d7b..9c5984e 100644 --- a/tplus/client/clearingengine/assetregistry.py +++ b/tplus/client/clearingengine/assetregistry.py @@ -42,3 +42,23 @@ async def update_risk_parameters(self): Request that the clearing engine updates its registered risk parameters. """ await self._post("params/update") + + async def set_registry_address(self, registry_address: ChainAddress): + """ + Admin-only endpoint for setting the registry address. Used in testing environment. + """ + payload = registry_address.model_dump() + await self._post("admin/registry/update-address", json_data=payload) + + async def update_fee_account(self): + """ + Request that the clearing engine update its fee account. + """ + await self._post("fee-account/update") + + async def get_fee_account(self) -> str: + """ + Get the fee account. + """ + account = await self._get("fee-account") + return f"{account}" diff --git a/tplus/client/clearingengine/decimal.py b/tplus/client/clearingengine/decimal.py index ace65a4..34e69af 100644 --- a/tplus/client/clearingengine/decimal.py +++ b/tplus/client/clearingengine/decimal.py @@ -2,16 +2,25 @@ from tplus.client.clearingengine.base import BaseClearingEngineClient from tplus.model.asset_identifier import AssetIdentifier +from tplus.model.chain_address import ChainAddress +from tplus.model.types import ChainID def _prep_request( - asset_ids: Sequence[str | AssetIdentifier] | str | AssetIdentifier, chains: list[str] | str + asset_ids: Sequence[str | AssetIdentifier] | str | AssetIdentifier, chains: list[ChainID] | str ) -> dict: - asset_ids_seq = [asset_ids] if isinstance(asset_ids, str) else asset_ids - chains = [chains] if isinstance(chains, str) else chains + asset_ids_seq = ( + [asset_ids] + if isinstance(asset_ids, str) or isinstance(asset_ids, AssetIdentifier) + else asset_ids + ) + + # type ignore to avoid type-related bugs when using str instead of ChainID. + chains = [chains] if isinstance(chains, str) else chains # type: ignore + assets = [] for asset in asset_ids_seq: - if isinstance(asset, AssetIdentifier): + if isinstance(asset, ChainAddress): # Already validated. assets.append(asset.model_dump()) @@ -34,7 +43,7 @@ class DecimalClient(BaseClearingEngineClient): APIs related to decimals. """ - async def get(self, asset_id: list[str | AssetIdentifier], chains: list[str]) -> dict: + async def get(self, asset_id: list[str | AssetIdentifier], chains: list[ChainID]) -> dict: """ Get CE cached decimals for the given assets and chains. @@ -48,7 +57,7 @@ async def get(self, asset_id: list[str | AssetIdentifier], chains: list[str]) -> request = _prep_request(asset_id, chains) return await self._get("decimals", json_data=request) - async def update(self, asset_id: list[str | AssetIdentifier], chains: list[str]): + async def update(self, asset_id: list[str | AssetIdentifier], chains: list[ChainID]): """ Request that the CE update cache decimals for the given assets and chains. """ diff --git a/tplus/client/clearingengine/vault.py b/tplus/client/clearingengine/vault.py index 7acbd41..9e9be63 100644 --- a/tplus/client/clearingengine/vault.py +++ b/tplus/client/clearingengine/vault.py @@ -32,3 +32,10 @@ async def get(self) -> list[ChainAddress]: """ result: list = await self._get("vaults") or [] # type: ignore return [ChainAddress.model_validate(a) for a in result] + + async def set_credential_manager_address(self, credential_manager: ChainAddress): + """ + Admin-only endpoint for setting the credential manager address. Used in testing environment. + """ + payload = credential_manager.model_dump() + await self._post("admin/credential-manager/update-address", json_data=payload) diff --git a/tplus/evm/contracts.py b/tplus/evm/contracts.py index 0c4cc0a..c2a6a5f 100644 --- a/tplus/evm/contracts.py +++ b/tplus/evm/contracts.py @@ -27,6 +27,8 @@ from ape.contracts.base import ContractContainer, ContractInstance from ape.managers.project import LocalProject + from tplus.utils.user import User + CHAIN_MAP = { 1: "ethereum:mainnet", 11155111: "ethereum:sepolia", @@ -109,6 +111,9 @@ def load_tplus_contracts_project(version: str | None = None) -> "LocalProject": # Working from the t+ contracts repo return ManagerAccessMixin.local_project + elif path := os.environ.get("TPLUS_CONTRACTS_PATH"): + return Project(path) + # Load the project from dependencies. try: project = _load_tplus_contracts_from_dependencies(version=version) @@ -193,7 +198,7 @@ def __init__( address: str | None = None, tplus_contracts_version: str | None = None, ) -> None: - self._deployments: dict[int, ContractInstance] = {} + self._deployments: dict[str, ContractInstance] = {} self._default_deployer = default_deployer if isinstance(chain_id, int): @@ -202,9 +207,12 @@ def __init__( self._chain_id = chain_id self._address = address self._tplus_contracts_version = tplus_contracts_version + self._attempted_deploy_dev = False if address is not None and chain_id is not None: - self._deployments[f"{chain_id}"] = self._contract_container.at(address) + self._deployments[f"{chain_id}"] = self._contract_container.at( + address, detect_proxy=False, fetch_from_explorer=False + ) @classmethod def at(cls, address: str) -> "TPlusContract": @@ -233,6 +241,13 @@ def deploy_dev(cls, **kwargs): owner = kwargs.get("sender") or get_dev_default_owner() return cls.deploy(owner, sender=owner) + def deploy_dev_and_set_deployment(self) -> "TPlusContract": + self._attempted_deploy_dev = True + instance = self.deploy_dev() + self._address = instance.address + self._deployments[f"{instance.chain_id}"] = instance + return instance + def __repr__(self) -> str: return f"<{self.name}>" @@ -241,14 +256,18 @@ def __getattr__(self, attr_name: str): # First, try a regular attribute on the class return self.__getattribute__(attr_name) except AttributeError: - # Resort to something defined on the contract. + if attr_name.startswith("_"): + # Ignore internals, causes integration issues. + raise + + # Try something defined on the contract. return getattr(self.contract, attr_name) @property def name(self) -> str: return self.__class__.NAME - @property + @cached_property def chain_id(self) -> ChainID: return self._chain_id or ChainID.evm(self.chain_manager.chain_id) @@ -257,8 +276,7 @@ def address(self) -> str: if address := self._address: return address - chain_id = self._chain_id or ChainID.evm(self.chain_manager.chain_id) - return self.get_address(chain_id=chain_id) + return self.get_address(chain_id=self.chain_id) @property def tplus_contracts_project(self) -> "Project": @@ -277,15 +295,16 @@ def contract(self) -> "ContractInstance": try: return self.get_contract() except ContractNotExists: - if self.chain_manager.provider.network.is_local: + if self.is_local_network and not self._attempted_deploy_dev: # If simulating, deploy it now. - instance = self.deploy_dev() - self._address = instance.address - self._deployments[self.chain_manager.chain_id] = instance - return instance + return self.deploy_dev_and_set_deployment() raise # This error. + @property + def is_local_network(self) -> bool: + return self.chain_manager.provider.network.is_local + @property def default_deployer(self) -> AccountAPI: if deployer := self._default_deployer: @@ -322,13 +341,15 @@ def get_contract(self, chain_id: ChainID | None = None) -> "ContractInstance": Returns: ContractInstance """ - chain_id = chain_id or self._chain_id or ChainID.evm(self.chain_manager.chain_id) + chain_id = chain_id or self.chain_id if chain_id in self._deployments: # Get previously cached instance. return self._deployments[chain_id] address = self.get_address(chain_id=chain_id) - contract_container = self._contract_container.at(address) + contract_container = self._contract_container.at( + address, detect_proxy=False, fetch_from_explorer=False + ) # Cache for next time. self._deployments[chain_id] = contract_container @@ -339,11 +360,18 @@ def get_address(self, chain_id: ChainID | None = None) -> str: if self._address and self._chain_id and chain_id == self._chain_id: return self._address - chain_id = chain_id or self._chain_id or ChainID.evm(self.chain_manager.chain_id) + chain_id = chain_id or self.chain_id try: return TPLUS_DEPLOYMENTS[chain_id][self.name] - except KeyError: - raise ContractNotExists(f"{self.name} not deployed on chain '{chain_id}'.") + except KeyError as err: + if self.is_local_network and not self._attempted_deploy_dev: + try: + return self.deploy_dev_and_set_deployment().address + except Exception: + # Raise `ContractNotExists` below. + pass + + raise ContractNotExists(f"{self.name} not deployed on chain '{chain_id}'.") from err class Registry(TPlusContract): @@ -383,17 +411,25 @@ def set_asset( self, index: int, asset_address: HexBytes32 | AddressType, - chain_id: int, + chain_id: ChainID, max_deposit: int, + max_1hr_deposits: int, + min_weight: int, sender=None, ) -> None: if isinstance(asset_address, str) and len(asset_address) <= 42: # Given EVM style address. Store as right-padded address. asset_address = to_bytes32(asset_address, pad="r") - return self.contract.setAssetData( - index, (asset_address, chain_id, max_deposit), sender=sender - ) + data = { + "index": index, + "assetAddress": asset_address, + "chainId": {"routingId": chain_id.routing_id, "vmId": chain_id.vm_id}, + "maxDeposits": max_deposit, + "max1hrDeposits": max_1hr_deposits, + "minWeight": min_weight, + } + return self.contract.setAssetData(data, sender=sender) class DepositVault(TPlusContract): @@ -435,6 +471,24 @@ def add_settler_executor( ) -> "ReceiptAPI": return self.addSettlerExecutor(settler, executor, **kwargs) + def get_settlement_count(self, user: "UserPublicKey | User", account_index: int) -> int: + if not isinstance(user, UserPublicKey): + user = user.public_key + + return self.contract.settlementCounts(user, account_index) + + def get_deposit_count(self, user: "UserPublicKey | User", account_index: int) -> int: + if not isinstance(user, UserPublicKey): + user = user.public_key + + return self.contract.depositCounts(user, account_index) + + def get_withdrawal_count(self, user: "UserPublicKey | User", account_index: int) -> int: + if not isinstance(user, UserPublicKey): + user = user.public_key + + return self.contract.withdrawalCounts(user, account_index) + def deposit( self, user: UserPublicKey, @@ -455,6 +509,7 @@ def execute_atomic_settlement( self, settlement: dict, user: UserPublicKey, + sub_account: int, expiry: int, data: HexBytes, signature: HexBytes, @@ -462,11 +517,11 @@ def execute_atomic_settlement( ) -> "ReceiptAPI": try: return self.contract.executeAtomicSettlement( - settlement, user, expiry, data, signature, **tx_kwargs + settlement, user, sub_account, expiry, data, signature, **tx_kwargs ) except Exception as err: err_id = getattr(err, "message", "") - if erc20_err_name := _decode_erc20_error(err.message): + if erc20_err_name := _decode_erc20_error(getattr(err, "message", f"{err}")): raise ContractLogicError(erc20_err_name) from err elif err_id == "0x203d82d8": @@ -500,7 +555,9 @@ def deploy_dev(cls, sender: AccountAPI | None = None, **kwargs) -> TPlusContract """ Deploy and set up a development vault. """ - credman = kwargs.get("credential_manager") or ZERO_ADDRESS + if not (credman := kwargs.get("credential_manager")): + credman = credential_manager.address + sender = sender or cls.account_manager.test_accounts[0] contract = cast(DepositVault, cls.deploy(sender, credman, sender=sender)) @@ -553,7 +610,10 @@ def deploy_dev(cls, **kwargs) -> "ReceiptAPI": owner = kwargs.get("sender") or get_dev_default_owner() operators = kwargs.get("operators", [owner.address]) threshold = kwargs.get("quorum_threshold") or len(operators) - registry_address = kwargs.get("registry") or ZERO_ADDRESS + + if not (registry_address := kwargs.get("registry")): + registry_address = registry.address + measurements = kwargs.get("measurements") or [] automata_verifier = kwargs.get("automata_verifier") or ZERO_ADDRESS diff --git a/tplus/evm/managers/chaindata.py b/tplus/evm/managers/chaindata.py index 074f632..324af8e 100644 --- a/tplus/evm/managers/chaindata.py +++ b/tplus/evm/managers/chaindata.py @@ -8,6 +8,7 @@ if TYPE_CHECKING: from tplus.model.asset_identifier import AssetIdentifier + from tplus.model.types import UserPublicKey from tplus.utils.user import User @@ -18,13 +19,13 @@ class ChainDataFetcher(ChainConnectedManager): def __init__( self, - tplus_user: "User", + default_user: "User", clearing_engine: ClearingEngineClient | None = None, chain_id: ChainID | None = None, ): - self.tplus_user = tplus_user + self.default_user = default_user self.ce: ClearingEngineClient = clearing_engine or ClearingEngineClient( - self.tplus_user, "http://127.0.0.1:3032" + self.default_user, "http://127.0.0.1:3032" ) self.chain_id = chain_id or ChainID.evm(self.chain_manager.chain_id) @@ -78,11 +79,13 @@ async def sync_vaults(self): async def sync_assets(self): await self.ce.assets.update() - async def sync_deposits(self): - await self.ce.deposits.update_nonce(self.tplus_user.public_key, self.chain_id) + async def sync_deposits(self, user: "UserPublicKey | None" = None) -> None: + user = user or self.default_user.public_key + await self.ce.deposits.update_nonce(user, self.chain_id) - async def sync_settlements(self): - await self.ce.settlements.update_nonce(self.tplus_user.public_key, self.chain_id) + async def sync_settlements(self, user: "UserPublicKey | None" = None): + user = user or self.default_user.public_key + await self.ce.settlements.update_nonce(user, self.chain_id) async def update_decimals(self, assets: Sequence["AssetIdentifier"]): await self.ce.decimals.update( diff --git a/tplus/evm/managers/credential_manager.py b/tplus/evm/managers/credential_manager.py index 2022042..f89282b 100644 --- a/tplus/evm/managers/credential_manager.py +++ b/tplus/evm/managers/credential_manager.py @@ -130,10 +130,10 @@ async def add_vault( update_fn=lambda: ce.vaults.update(), get_fn=lambda: ce.vaults.get(), # cond: checks if the vault address is part any of the ChainAddress returned. - check_fn=lambda vaults: any(vault in vault_ca for vault_ca in vaults), + check_fn=lambda vaults: vault in vaults, timeout=10, interval=1, - error_msg="Vault registration failed.", + error_msg=f"Vault registration failed (vault={vault}).", ) return tx diff --git a/tplus/evm/managers/deposit.py b/tplus/evm/managers/deposit.py index 3c76c1b..476e40b 100644 --- a/tplus/evm/managers/deposit.py +++ b/tplus/evm/managers/deposit.py @@ -11,6 +11,7 @@ from ape.types.address import AddressType from tplus.client.clearingengine import ClearingEngineClient + from tplus.model.types import UserPublicKey from tplus.utils.user import User @@ -18,21 +19,26 @@ class DepositManager(ChainConnectedManager): def __init__( self, account: "AccountAPI", - tplus_user: "User", + default_user: "User", vault: DepositVault | None = None, chain_id: ChainID | None = None, clearing_engine: "ClearingEngineClient | None" = None, ): self.account = account - self.tplus_user = tplus_user + self.default_user = default_user self.chain_id = chain_id or ChainID.evm(self.chain_manager.chain_id) self.ce = clearing_engine self.vault = vault if vault else DepositVault(chain_id=self.chain_id) async def deposit( - self, token: "str | AddressType | ContractInstance", amount: int, wait: bool = False + self, + token: "str | AddressType | ContractInstance", + amount: int, + wait: bool = False, + user: "UserPublicKey | None" = None, ): - self.vault.deposit(self.tplus_user.public_key, token, amount, sender=self.account) + user = user or self.default_user.public_key + self.vault.deposit(user, token, amount, sender=self.account) if wait: if not (ce := self.ce): @@ -41,4 +47,4 @@ async def deposit( # There actually isn't a way to really wait for deposits since there isn't an API # to "get" them. Instead, just wait 3 seconds. await asyncio.sleep(3) - await ce.deposits.update_nonce(self.tplus_user.public_key, self.chain_id) + await ce.deposits.update_nonce(user, self.chain_id) diff --git a/tplus/evm/managers/registry.py b/tplus/evm/managers/registry.py index 20c6219..0a5623e 100644 --- a/tplus/evm/managers/registry.py +++ b/tplus/evm/managers/registry.py @@ -3,9 +3,12 @@ from tplus.evm.contracts import Registry from tplus.evm.managers.evm import ChainConnectedManager from tplus.model.types import ChainID +from tplus.utils.timeout import wait_for_condition if TYPE_CHECKING: from ape.api.accounts import AccountAPI + from ape.types.address import AddressType + from eth_pydantic_types.hex.bytes import HexBytes32 from tplus.client.clearingengine import ClearingEngineClient @@ -31,3 +34,41 @@ def __init__( self.registry = registry self.ce = clearing_engine + + async def set_asset( + self, + index: int, + asset_address: "HexBytes32 | AddressType", + chain_id: ChainID, + max_deposit: int, + max_1hr_deposits: int, + min_weight: int, + wait: bool = True, + **tx_kwargs, + ) -> None: + if "sender" not in tx_kwargs: + tx_kwargs["sender"] = self.owner + + self.registry.set_asset( + index, + asset_address, + chain_id, + max_deposit, + max_1hr_deposits, + min_weight, + **tx_kwargs, + ) + + if wait: + if not (ce := self.ce): + raise ValueError("Must have clearing_engine to wait for asset registration.") + + await wait_for_condition( + update_fn=lambda: ce.assets.update(), + get_fn=lambda: ce.assets.get(), + # cond: checks if the vault address is part any of the ChainAddress returned. + check_fn=lambda assets: f"{index}" in assets, + timeout=10, + interval=1, + error_msg=f"Asset registration failed (asset={index}).", + ) diff --git a/tplus/evm/managers/settle.py b/tplus/evm/managers/settle.py index 66a9cfd..26778bb 100644 --- a/tplus/evm/managers/settle.py +++ b/tplus/evm/managers/settle.py @@ -1,3 +1,4 @@ +import asyncio from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass from functools import cached_property @@ -11,8 +12,9 @@ from tplus.evm.managers.deposit import DepositManager from tplus.evm.managers.evm import ChainConnectedManager from tplus.logger import get_logger +from tplus.model.approval import SettlementApproval from tplus.model.settlement import TxSettlementRequest -from tplus.model.types import ChainID +from tplus.model.types import ChainID, UserPublicKey from tplus.utils.amount import Amount from tplus.utils.user.decrypt import decrypt_settlement_approval @@ -37,7 +39,7 @@ class SettlementInfo: amount_in: Amount asset_out: "AssetIdentifier" amount_out: Amount - expected_nonce: int + nonce: int class SettlementManager(ChainConnectedManager): @@ -48,17 +50,17 @@ class SettlementManager(ChainConnectedManager): def __init__( self, - tplus_user: "User", + default_user: "User", ape_account: "AccountAPI", clearing_engine: ClearingEngineClient | None = None, chain_id: ChainID | None = None, vault: DepositVault | None = None, settlement_vault: DepositVault | None = None, ): - self.tplus_user = tplus_user + self.default_user = default_user self.ape_account = ape_account self.ce: ClearingEngineClient = clearing_engine or ClearingEngineClient( - self.tplus_user, "http://127.0.0.1:3032" + self.default_user, "http://127.0.0.1:3032" ) self.chain_id = chain_id or ChainID.evm(self.chain_manager.chain_id) self.vault = vault or DepositVault(chain_id=self.chain_id) @@ -69,11 +71,13 @@ def __init__( self.settlement_vault = settlement_vault or self.vault self.logger = get_logger() + self._approval_handling_tasks = {} + @cached_property def deposits(self) -> DepositManager: return DepositManager( self.ape_account, - self.tplus_user, + self.default_user, vault=self.vault, chain_id=self.chain_id, clearing_engine=self.ce, @@ -82,7 +86,7 @@ def deposits(self) -> DepositManager: @cached_property def chaindata(self) -> ChainDataFetcher: return ChainDataFetcher( - self.tplus_user, + self.default_user, self.ce, self.chain_id, ) @@ -108,16 +112,21 @@ async def prefetch_chaindata( settlements=settlements, ) - def _decrypt_settlement_approval_message(self, message: dict) -> dict | None: + def decrypt_settlement_approval_message( + self, message: dict, user: "User | None" = None + ) -> SettlementApproval | None: """ Decrypt and parse a settlement approval message from the WebSocket. Args: message: The raw message dictionary from the WebSocket containing encrypted_data. + user: Specify the tplus user. Defaults to the default_user Returns: - dict: The decrypted approval dictionary, or None if decryption/parsing fails. + SettlementApproval: The decrypted approval dictionary, or None if decryption/parsing fails. """ + user = user or self.default_user + try: encrypted_data = message["encrypted_data"] except KeyError as err: @@ -131,17 +140,23 @@ def _decrypt_settlement_approval_message(self, message: dict) -> dict | None: return None try: - return decrypt_settlement_approval(encrypted_data_bytes, self.tplus_user.sk) + data = decrypt_settlement_approval(encrypted_data_bytes, user.sk) except Exception as err: self.logger.warning(f"Failed to decrypt approval: {err}") return None + return SettlementApproval.model_validate(data) + async def init_settlement( self, asset_in: "AssetIdentifier", amount_in: Amount, asset_out: "AssetIdentifier", amount_out: Amount, + user: "User | None" = None, + account_index: int | None = None, + then_execute: bool = False, + on_approved: "Callable[[SettlementInfo, SettlementApproval], Awaitable[None] | None] | None" = None, ) -> SettlementInfo: """ Initialize a settlement asynchronously without waiting for approval. @@ -154,16 +169,28 @@ async def init_settlement( amount_in: Both the normalized and atomic amounts for the amount going into the protocol. asset_out: The ID of the asset leaving the protocol. amount_out: Both the normalized and atomic amounts for the amount leaving the protocol. + user: Specify the tplus user. Defaults to the default tplus user. + account_index: Specify the index of the tplus account for this settlement approval. Defaults to the + selected user's account index. + then_execute: Set to ``True`` to wait for the approval and then execute the settlement on-chain. + on_approved: Custom callback for receiving the approval from the CE. Returns: SettlementInfo: Information about the settlement including the expected nonce. """ + if on_approved and then_execute: + raise ValueError("Cannot provide both `on_approved` and `then_execute` arguments.") + # Get the expected nonce (current count before this settlement - it will increment after init) - expected_nonce = self.vault.settlementCounts(self.tplus_user.public_key) + user = user or self.default_user + expected_nonce = self.vault.settlementCounts(user.public_key, user.sub_account) amount_in_normalized = amount_in.to_inventory_amount("up") amount_out_normalized = amount_out.to_inventory_amount("down") + if account_index is None: + account_index = user.sub_account + request = TxSettlementRequest.create_signed( { "chain_id": self.chain_id, @@ -171,32 +198,76 @@ async def init_settlement( "amount_in": amount_in_normalized, "asset_out": asset_out, "amount_out": amount_out_normalized, + "sub_account_index": account_index, }, - self.tplus_user, + user, ) + await self._init_settlement(request) self.logger.info( - f"Initialized settlement - Asset in: {asset_in.evm_address}, " - f"Amount in: {amount_in.amount}, Asset out: {asset_out.evm_address}, " + f"Initialized settlement - Asset in: {asset_in}, " + f"Amount in: {amount_in.amount}, Asset out: {asset_out}, " f"Amount out: {amount_out.amount}, Expected nonce: {expected_nonce}" ) - return SettlementInfo( + settlement_info = SettlementInfo( asset_in=asset_in, amount_in=amount_in, asset_out=asset_out, amount_out=amount_out, - expected_nonce=expected_nonce, + nonce=expected_nonce, ) + if on_approved or then_execute: + handler = SettlementApprovalHandler(self) + + if then_execute: + + async def on_approved(info, approval): + await self.execute_settlement(info, approval) + + async def approval_handling_task_fn(): + try: + async with asyncio.timeout(12): + await handler.handle_approvals( + on_approval_received=on_approved, + stop_at=1, + pending_settlements={expected_nonce: settlement_info}, + ) + except TimeoutError: + self.logger.info("Approval handler timed out") + + approval_handling_task = asyncio.create_task(approval_handling_task_fn()) + + self._approval_handling_tasks.setdefault(user.public_key, {}) + + self._approval_handling_tasks[user.public_key][settlement_info.nonce] = ( + approval_handling_task + ) + + def _cleanup(_task: asyncio.Task): + tasks = self._approval_handling_tasks.get(user.public_key) + if not tasks: + return + + tasks.pop(settlement_info.nonce, None) + + if not tasks: + self._approval_handling_tasks.pop(user.public_key, None) + + approval_handling_task.add_done_callback(_cleanup) + + return settlement_info + async def _init_settlement(self, request: "TxSettlementRequest"): return await self.ce.settlements.init_settlement(request) async def execute_settlement( self, settlement_info: SettlementInfo, - approval: dict, + approval: SettlementApproval, + user: "UserPublicKey | None" = None, **kwargs, ) -> "ReceiptAPI": """ @@ -204,64 +275,57 @@ async def execute_settlement( Args: settlement_info: The settlement information from initialization. - approval: The decrypted approval dictionary from the clearing-engine. + approval: The decrypted approval from the clearing-engine. + user: Specify the tplus user. Defaults to the default tplus user. kwargs: Additional tx properties to pass to ``executeAtomicSettlement()`` e.g. ``gas=`` or ``required_confirmations=``. Returns: ReceiptAPI: The transaction receipt. """ - nonce = approval["inner"]["nonce"] - expiry = approval["expiry"] + nonce = approval.inner.nonce + expiry = approval.expiry + user = user or self.default_user + token_in_address = kwargs.pop("token_in", None) + token_out_address = kwargs.pop("token_out", None) # Validate that the approval matches the expected nonce - if nonce != settlement_info.expected_nonce: + if nonce != settlement_info.nonce: raise ValueError( - f"Approval nonce {nonce} does not match expected nonce {settlement_info.expected_nonce}" + f"Approval nonce {nonce} does not match expected nonce {settlement_info.nonce}" ) kwargs.setdefault("sender", self.ape_account) kwargs.setdefault("required_confirmations", 0) - self.logger.info( - "Executing settlement: " - f"Vault: {self.vault.address}, " - f"Chain ID: {self.chain_id}, " - f"User: {self.tplus_user.public_key}, " - f"Asset in: {settlement_info.asset_in.evm_address}, " - f"Amount in: {settlement_info.amount_in.amount}, " - f"Asset out: {settlement_info.asset_out.evm_address}, " - f"Amount out: {settlement_info.amount_out.amount}, " - f"Nonce: {nonce}, " - f"Expiry: {expiry}, " - f"Domain separator: {self.vault.domain_separator.hex()}" - ) + if token_in_address is None: + token_in_address = settlement_info.asset_in.evm_address + if token_out_address is None: + token_out_address = settlement_info.asset_out.evm_address # Execute the settlement on-chain. tx = self.settlement_vault.execute_atomic_settlement( { - "tokenIn": settlement_info.asset_in.evm_address, + "tokenIn": token_in_address, "amountIn": settlement_info.amount_in.amount, - "tokenOut": settlement_info.asset_out.evm_address, + "tokenOut": token_out_address, "amountOut": settlement_info.amount_out.amount, "nonce": nonce, }, - HexBytes(self.tplus_user.public_key), + HexBytes(user.public_key), + user.sub_account, expiry, "", - HexBytes(approval["inner"]["signature"]), + HexBytes(approval.inner.signature), **kwargs, ) return tx - async def get_approvals(self) -> list[dict]: - return await self.ce.settlements.get_signatures(self.tplus_user.public_key) - class SettlementApprovalHandler: """ - Handles settlement approval stream independently from settlement initialization. + Handles settlement approval stream independently of settlement initialization. Can be run in a separate async task to process approvals as they arrive. """ @@ -271,58 +335,66 @@ def __init__( ): self.settlement_manager = settlement_manager self.logger = settlement_manager.logger + self.on_approval_received = None async def handle_approvals( self, - pending_settlements: dict[int, SettlementInfo], on_approval_received: ( - "Callable[[SettlementInfo, dict], Awaitable[None] | None] | None" + "Callable[[SettlementInfo | None, SettlementApproval], Awaitable[None] | None] | None" ) = None, + pending_settlements: dict[int, SettlementInfo] | None = None, + stop_at: int | None = None, + user: "UserPublicKey | None" = None, ) -> None: """ Continuously listen for settlement approvals and match them with pending settlements. Args: - pending_settlements: Dictionary mapping nonce -> SettlementInfo for settlements - waiting for approval. Approved settlements will be removed from this dict. on_approval_received: Optional callback function that will be called when an approval is received. Called with (settlement_info, approval_dict). If not provided, approvals will just be logged. + pending_settlements: Dictionary mapping nonce -> SettlementInfo for settlements + waiting for approval. Approved settlements will be removed from this dict, if given. + To handle any settlement regardless, pass ``None`` or leave as default. + stop_at: The amount of approvals to handle before stopping. + user: Specify the user. Defaults to the default user. """ - self.logger.info( - f"Starting approval handler for user {self.settlement_manager.tplus_user.public_key}" - ) - - async for message in self.settlement_manager.ce.settlements.stream_approvals( - self.settlement_manager.tplus_user.public_key - ): - approval = self.settlement_manager._decrypt_settlement_approval_message(message) - if approval is None: - continue - - nonce = approval.get("inner", {}).get("nonce") - if nonce is None: - continue - - # Check if we have a pending settlement for this nonce - settlement_info = pending_settlements.get(nonce) - if settlement_info is None: - self.logger.debug(f"Received approval for unknown nonce {nonce}, ignoring") - continue + user = user or self.settlement_manager.default_user.public_key + self.logger.info(f"Starting approval handler for user {user}") + amount_handled = 0 - self.logger.info(f"Received approval for nonce {nonce}") - - # Remove from pending - del pending_settlements[nonce] - - # Call callback if provided - if on_approval_received: - try: - result = on_approval_received(settlement_info, approval) - if isinstance(result, Awaitable): - await result - except Exception as err: - self.logger.error( - f"Error in on_approval_received callback for nonce {nonce}: {err}", - exc_info=True, - ) + try: + async for message in self.settlement_manager.ce.settlements.stream_approvals(user): + approval = self.settlement_manager.decrypt_settlement_approval_message(message) + if approval is None: + continue + + nonce = approval.inner.nonce + pending_settlements = pending_settlements or {} + settlement_info = pending_settlements.get(nonce) + if settlement_info is None: + self.logger.debug(f"Received approval for unknown nonce {nonce}, ignoring") + continue + + else: + self.logger.info(f"Received approval for nonce {nonce}") + del pending_settlements[nonce] + + callback = on_approval_received or self.on_approval_received + if callback: + try: + result = callback(settlement_info, approval) + if asyncio.iscoroutine(result): + await result + except Exception as err: + self.logger.error( + f"Error in on_approval_received callback for nonce {nonce}: {err}", + exc_info=True, + ) + + amount_handled += 1 + if stop_at is not None and amount_handled >= stop_at: + break + + except asyncio.TimeoutError: + self.logger.info("Approval handler timed out") diff --git a/tplus/model/approval.py b/tplus/model/approval.py new file mode 100644 index 0000000..f16a619 --- /dev/null +++ b/tplus/model/approval.py @@ -0,0 +1,12 @@ +from eth_pydantic_types import HexBytes +from pydantic import BaseModel + + +class InnerSettlementApproval(BaseModel): + nonce: int + signature: HexBytes + + +class SettlementApproval(BaseModel): + inner: InnerSettlementApproval + expiry: int diff --git a/tplus/model/asset_identifier.py b/tplus/model/asset_identifier.py index d5f52c3..1efb4f3 100644 --- a/tplus/model/asset_identifier.py +++ b/tplus/model/asset_identifier.py @@ -1,3 +1,4 @@ +from functools import cached_property from typing import Any from pydantic import model_serializer, model_validator @@ -28,6 +29,17 @@ def _validate_input(cls, data: Any) -> Any: def __str__(self) -> str: return str(self.root) + @cached_property + def indexed(self): + return "@" not in self.root + @model_serializer def serialize_model(self) -> str: return self.root + + @property + def evm_address(self) -> str: + if self.indexed: + raise ValueError("Indexed asset identifiers do not have an address.") + + return super().evm_address diff --git a/tplus/model/chain_address.py b/tplus/model/chain_address.py index 66c171f..1cbad8f 100644 --- a/tplus/model/chain_address.py +++ b/tplus/model/chain_address.py @@ -60,7 +60,7 @@ def validate_chain_address(chain_address: str) -> str: # Case 1: Already validated. if isinstance(chain_address, ChainAddress): # Already validated. - return chain_address + return f"{chain_address}" # Case 2: Input is a dictionary from the backend (e.g., from JSON deserialization) elif isinstance(chain_address, dict): @@ -111,6 +111,9 @@ def __contains__(self, key: Any) -> bool: return False + def __eq__(self, other: Any) -> bool: + return f"{other}" == f"{self}" + def _contains_address_str(self, value: str) -> bool: key_str = value.removeprefix("0x") if len(key_str) % 2 != 0: @@ -147,7 +150,10 @@ def evm_address(self) -> str: except ImportError: return address # Non-checksummed. - return to_checksum_address(address) + try: + return to_checksum_address(address) + except Exception as err: + raise ValueError(f"Invalid address '{address}'") from err @cached_property def chain_id(self) -> ChainID: diff --git a/tplus/model/settlement.py b/tplus/model/settlement.py index fd934b7..c868e3c 100644 --- a/tplus/model/settlement.py +++ b/tplus/model/settlement.py @@ -6,6 +6,7 @@ from pydantic import BaseModel, field_serializer from tplus.model.asset_identifier import AssetIdentifier +from tplus.model.chain_address import ChainAddress from tplus.model.types import ChainID, UserPublicKey from tplus.utils.decimals import to_inventory_decimals from tplus.utils.hex import str_to_vec @@ -35,6 +36,7 @@ class InnerSettlementRequest(BaseSettlement): """ tplus_user: UserPublicKey + sub_account_index: int settler: UserPublicKey chain_id: ChainID @@ -49,6 +51,7 @@ def from_raw( decimals_out: int, tplus_user: UserPublicKey, chain: ChainID | str, + sub_account_index: int, settler: UserPublicKey | None = None, ) -> "InnerSettlementRequest": """ @@ -69,14 +72,13 @@ def from_raw( settlement will occur. settler (:class:`~tplus.models.types.UserPublicKey`): The settler tplus account. If not provided, uses the same account as ``tplus_user``. + sub_account_index (int): The settler account index to pull funds from. Returns: InnerSettlementRequest: A normalized settlement request ready for processing. """ - if isinstance(asset_in, str) and "@" not in asset_in: - asset_in = AssetIdentifier(f"{asset_in}@{chain}") - if isinstance(asset_out, str) and "@" not in asset_out: - asset_out = AssetIdentifier(f"{asset_out}@{chain}") + asset_in = cls._validate_asset(asset_in, chain) + asset_out = cls._validate_asset(asset_out, chain) return cls.model_validate( { @@ -87,9 +89,19 @@ def from_raw( "tplus_user": tplus_user, "settler": settler or tplus_user, "chain_id": chain, + "sub_account_index": sub_account_index, } ) + @classmethod + def _validate_asset(cls, asset, chain: ChainID | str) -> AssetIdentifier: + if isinstance(asset, ChainAddress) and not isinstance(asset, AssetIdentifier): + return AssetIdentifier.model_validate(f"{asset}") + elif isinstance(asset, str) and "@" not in asset: + return AssetIdentifier(f"{asset}@{chain}") + + return asset + def signing_payload(self) -> str: base_data = self.model_dump(mode="json", exclude_none=True) @@ -100,12 +112,18 @@ def signing_payload(self) -> str: # NOTE: The order here matters! payload = { "tplus_user": user, + "sub_account_index": base_data.pop("sub_account_index"), "settler": settler, **base_data, "chain_id": chain_id, } - return json.dumps(payload, separators=(",", ":")) + return ( + json.dumps(payload, separators=(",", ":")) + .replace(" ", "") + .replace("\n", "") + .replace("\t", "") + ) class TxSettlementRequest(BaseModel): @@ -147,7 +165,9 @@ def create_signed( inner = InnerSettlementRequest.model_validate(inner) - signature = str_to_vec(signer.sign(inner.signing_payload()).hex()) + signing_payload = inner.signing_payload() + + signature = str_to_vec(signer.sign(signing_payload).hex()) return cls(inner=inner, signature=signature) def signing_payload(self) -> str: diff --git a/tplus/utils/amount.py b/tplus/utils/amount.py index a44736d..dfb9e4a 100644 --- a/tplus/utils/amount.py +++ b/tplus/utils/amount.py @@ -10,7 +10,7 @@ class Amount(BaseModel): amount: int """ - An amount normalized to clearing-engine decimals. + An atomic amount. """ decimals: int diff --git a/tplus/utils/user/model.py b/tplus/utils/user/model.py index 2e024f0..a64e37e 100644 --- a/tplus/utils/user/model.py +++ b/tplus/utils/user/model.py @@ -8,10 +8,15 @@ from tplus.utils.user.validate import privkey_to_bytes SEED_SIZE = 32 +MAIN_SUB_ACCOUNT = 0 class User: - def __init__(self, private_key: str | bytes | Ed25519PrivateKey | None = None): + def __init__( + self, + private_key: str | bytes | Ed25519PrivateKey | None = None, + sub_account: int | None = None, + ): if private_key: if isinstance(private_key, str | bytes): private_key_bytes = privkey_to_bytes(private_key) @@ -30,6 +35,7 @@ def __init__(self, private_key: str | bytes | Ed25519PrivateKey | None = None): self.sk = Ed25519PrivateKey.generate() self.vk = self.sk.public_key() + self._sub_account = sub_account def __repr__(self) -> str: return f"" @@ -43,6 +49,10 @@ def public_key(self) -> UserPublicKey: def public_key_vec(self) -> list[int]: return str_to_vec(self.public_key) + @property + def sub_account(self) -> int: + return self._sub_account or MAIN_SUB_ACCOUNT + # Legacy: use `.public_key` (cached). def pubkey(self) -> str: return self.vk.public_bytes(Encoding.Raw, PublicFormat.Raw).hex() From 6c627d7f57b47e3eb97ceebf7391d70b8e6f922a Mon Sep 17 00:00:00 2001 From: antazoey Date: Tue, 24 Feb 2026 08:17:54 -0600 Subject: [PATCH 2/3] bump eip --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index fca350d..8f4bf6e 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,7 @@ "ape-tokens", "click", "eip712>=0.3.1", - "eth-ape>=0.8.32,<0.9", + "eth-ape>=0.8.48,<0.9", "hexbytes>=1.3.1,<2", ], }, From c50b8b7cf8661d8a15097935ccd0d82e2660778f Mon Sep 17 00:00:00 2001 From: antazoey Date: Tue, 24 Feb 2026 09:59:40 -0600 Subject: [PATCH 3/3] fixes from self review --- tplus/evm/contracts.py | 8 ++++---- tplus/evm/managers/registry.py | 1 - tplus/evm/managers/settle.py | 7 +++++++ 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/tplus/evm/contracts.py b/tplus/evm/contracts.py index c2a6a5f..406bf17 100644 --- a/tplus/evm/contracts.py +++ b/tplus/evm/contracts.py @@ -107,13 +107,13 @@ def load_tplus_contracts_project(version: str | None = None) -> "LocalProject": Else, it checks all Ape installed dependencies. If it is not installed, it will fail. Install the tplus-contracts project by running ``ape pm install tpluslabs/tplus-contracts``. """ - if ManagerAccessMixin.local_project.name == "tplus-contracts": + if path := os.environ.get("TPLUS_CONTRACTS_PATH"): + return Project(path) + + elif ManagerAccessMixin.local_project.name == "tplus-contracts": # Working from the t+ contracts repo return ManagerAccessMixin.local_project - elif path := os.environ.get("TPLUS_CONTRACTS_PATH"): - return Project(path) - # Load the project from dependencies. try: project = _load_tplus_contracts_from_dependencies(version=version) diff --git a/tplus/evm/managers/registry.py b/tplus/evm/managers/registry.py index 0a5623e..bf0f588 100644 --- a/tplus/evm/managers/registry.py +++ b/tplus/evm/managers/registry.py @@ -66,7 +66,6 @@ async def set_asset( await wait_for_condition( update_fn=lambda: ce.assets.update(), get_fn=lambda: ce.assets.get(), - # cond: checks if the vault address is part any of the ChainAddress returned. check_fn=lambda assets: f"{index}" in assets, timeout=10, interval=1, diff --git a/tplus/evm/managers/settle.py b/tplus/evm/managers/settle.py index 26778bb..26d8354 100644 --- a/tplus/evm/managers/settle.py +++ b/tplus/evm/managers/settle.py @@ -299,8 +299,15 @@ async def execute_settlement( kwargs.setdefault("required_confirmations", 0) if token_in_address is None: + if settlement_info.asset_in.indexed: + raise ValueError("Missing address for indexed asset-in, please specify.") + token_in_address = settlement_info.asset_in.evm_address + if token_out_address is None: + if settlement_info.asset_out.indexed: + raise ValueError("Missing address for indexed asset-out, please specify.") + token_out_address = settlement_info.asset_out.evm_address # Execute the settlement on-chain.