Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"ape-tokens",
"click",
"eip712>=0.3.1",
"eth-ape>=0.8.32,<0.9",
"eth-ape>=0.8.48,<0.9",
"hexbytes>=1.3.1,<2",
],
},
Expand Down
8 changes: 8 additions & 0 deletions tests/model/test_asset_identifier.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from tplus.model.asset_identifier import AssetIdentifier
from tplus.model.chain_address import ChainAddress


class TestAssetIdentifier:
Expand Down Expand Up @@ -130,3 +131,10 @@ def test_chain_id_too_small(self):
raw_str = "62622E77D1349Face943C6e7D5c01C61465FE1dc@a4b1"
with pytest.raises(ValueError):
_ = AssetIdentifier(raw_str)

def test_validate_from_chain_address(self):
address = ChainAddress.from_str(
"62622E77D1349Face943C6e7D5c01C61465FE1dc@000000000000aa36a7"
)
asset = AssetIdentifier.model_validate(address)
assert asset == address
13 changes: 11 additions & 2 deletions tests/model/test_settlement.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,15 @@
class TestInnerSettlementRequest:
def test_from_raw(self, user):
request = InnerSettlementRequest.from_raw(
ASSET_IN, 100, 6, ASSET_OUT, 100, 18, user.public_key, CHAIN_ID
ASSET_IN,
100,
6,
ASSET_OUT,
100,
18,
user.public_key,
CHAIN_ID,
0,
)
assert (
request.asset_in.root
Expand All @@ -35,6 +43,7 @@ class TestTxSettlementRequest:
def settlement(self, user):
return {
"tplus_user": user.public_key,
"sub_account_index": 0,
"settler": user.public_key,
**get_base_settlement_data(),
"chain_id": CHAIN_ID,
Expand All @@ -46,7 +55,7 @@ def test_signing_payload(self, settlement, user):
"""
settlement = TxSettlementRequest(inner=settlement, signature=[])
actual = settlement.signing_payload()
expected = f'{{"tplus_user":"{user.public_key}","settler":"{user.public_key}","asset_in":"62622e77d1349face943c6e7d5c01c61465fe1dc000000000000000000000000@000000000000aa36a7","amount_in":"9f4cfc56cd29b000","asset_out":"58372ab62269a52fa636ad7f200d93999595dcaf000000000000000000000000@000000000000aa36a7","amount_out":"8e1bc9bf04000","chain_id":"000000000000aa36a7"}}'
expected = f'{{"tplus_user":"{user.public_key}","sub_account_index":0,"settler":"{user.public_key}","asset_in":"62622e77d1349face943c6e7d5c01c61465fe1dc000000000000000000000000@000000000000aa36a7","amount_in":"9f4cfc56cd29b000","asset_out":"58372ab62269a52fa636ad7f200d93999595dcaf000000000000000000000000@000000000000aa36a7","amount_out":"8e1bc9bf04000","chain_id":"000000000000aa36a7"}}'
assert actual == expected

# Show it is the same as the inner version.
Expand Down
20 changes: 20 additions & 0 deletions tplus/client/clearingengine/assetregistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,23 @@ async def update_risk_parameters(self):
Request that the clearing engine updates its registered risk parameters.
"""
await self._post("params/update")

async def set_registry_address(self, registry_address: ChainAddress):
"""
Admin-only endpoint for setting the registry address. Used in testing environment.
"""
payload = registry_address.model_dump()
await self._post("admin/registry/update-address", json_data=payload)

async def update_fee_account(self):
"""
Request that the clearing engine update its fee account.
"""
await self._post("fee-account/update")

async def get_fee_account(self) -> str:
"""
Get the fee account.
"""
account = await self._get("fee-account")
return f"{account}"
21 changes: 15 additions & 6 deletions tplus/client/clearingengine/decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,25 @@

from tplus.client.clearingengine.base import BaseClearingEngineClient
from tplus.model.asset_identifier import AssetIdentifier
from tplus.model.chain_address import ChainAddress
from tplus.model.types import ChainID


def _prep_request(
asset_ids: Sequence[str | AssetIdentifier] | str | AssetIdentifier, chains: list[str] | str
asset_ids: Sequence[str | AssetIdentifier] | str | AssetIdentifier, chains: list[ChainID] | str
) -> dict:
asset_ids_seq = [asset_ids] if isinstance(asset_ids, str) else asset_ids
chains = [chains] if isinstance(chains, str) else chains
asset_ids_seq = (
[asset_ids]
if isinstance(asset_ids, str) or isinstance(asset_ids, AssetIdentifier)
else asset_ids
)

# type ignore to avoid type-related bugs when using str instead of ChainID.
chains = [chains] if isinstance(chains, str) else chains # type: ignore

assets = []
for asset in asset_ids_seq:
if isinstance(asset, AssetIdentifier):
if isinstance(asset, ChainAddress):
# Already validated.
assets.append(asset.model_dump())

Expand All @@ -34,7 +43,7 @@ class DecimalClient(BaseClearingEngineClient):
APIs related to decimals.
"""

async def get(self, asset_id: list[str | AssetIdentifier], chains: list[str]) -> dict:
async def get(self, asset_id: list[str | AssetIdentifier], chains: list[ChainID]) -> dict:
"""
Get CE cached decimals for the given assets and chains.

Expand All @@ -48,7 +57,7 @@ async def get(self, asset_id: list[str | AssetIdentifier], chains: list[str]) ->
request = _prep_request(asset_id, chains)
return await self._get("decimals", json_data=request)

async def update(self, asset_id: list[str | AssetIdentifier], chains: list[str]):
async def update(self, asset_id: list[str | AssetIdentifier], chains: list[ChainID]):
"""
Request that the CE update cache decimals for the given assets and chains.
"""
Expand Down
7 changes: 7 additions & 0 deletions tplus/client/clearingengine/vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,10 @@ async def get(self) -> list[ChainAddress]:
"""
result: list = await self._get("vaults") or [] # type: ignore
return [ChainAddress.model_validate(a) for a in result]

async def set_credential_manager_address(self, credential_manager: ChainAddress):
"""
Admin-only endpoint for setting the credential manager address. Used in testing environment.
"""
payload = credential_manager.model_dump()
await self._post("admin/credential-manager/update-address", json_data=payload)
110 changes: 85 additions & 25 deletions tplus/evm/contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from ape.contracts.base import ContractContainer, ContractInstance
from ape.managers.project import LocalProject

from tplus.utils.user import User

CHAIN_MAP = {
1: "ethereum:mainnet",
11155111: "ethereum:sepolia",
Expand Down Expand Up @@ -105,7 +107,10 @@ def load_tplus_contracts_project(version: str | None = None) -> "LocalProject":
Else, it checks all Ape installed dependencies. If it is not installed, it will fail.
Install the tplus-contracts project by running ``ape pm install tpluslabs/tplus-contracts``.
"""
if ManagerAccessMixin.local_project.name == "tplus-contracts":
if path := os.environ.get("TPLUS_CONTRACTS_PATH"):
return Project(path)

elif ManagerAccessMixin.local_project.name == "tplus-contracts":
# Working from the t+ contracts repo
return ManagerAccessMixin.local_project

Expand Down Expand Up @@ -193,7 +198,7 @@ def __init__(
address: str | None = None,
tplus_contracts_version: str | None = None,
) -> None:
self._deployments: dict[int, ContractInstance] = {}
self._deployments: dict[str, ContractInstance] = {}
self._default_deployer = default_deployer

if isinstance(chain_id, int):
Expand All @@ -202,9 +207,12 @@ def __init__(
self._chain_id = chain_id
self._address = address
self._tplus_contracts_version = tplus_contracts_version
self._attempted_deploy_dev = False

if address is not None and chain_id is not None:
self._deployments[f"{chain_id}"] = self._contract_container.at(address)
self._deployments[f"{chain_id}"] = self._contract_container.at(
address, detect_proxy=False, fetch_from_explorer=False
)

@classmethod
def at(cls, address: str) -> "TPlusContract":
Expand Down Expand Up @@ -233,6 +241,13 @@ def deploy_dev(cls, **kwargs):
owner = kwargs.get("sender") or get_dev_default_owner()
return cls.deploy(owner, sender=owner)

def deploy_dev_and_set_deployment(self) -> "TPlusContract":
self._attempted_deploy_dev = True
instance = self.deploy_dev()
self._address = instance.address
self._deployments[f"{instance.chain_id}"] = instance
return instance

def __repr__(self) -> str:
return f"<{self.name}>"

Expand All @@ -241,14 +256,18 @@ def __getattr__(self, attr_name: str):
# First, try a regular attribute on the class
return self.__getattribute__(attr_name)
except AttributeError:
# Resort to something defined on the contract.
if attr_name.startswith("_"):
# Ignore internals, causes integration issues.
raise

# Try something defined on the contract.
return getattr(self.contract, attr_name)

@property
def name(self) -> str:
return self.__class__.NAME

@property
@cached_property
def chain_id(self) -> ChainID:
return self._chain_id or ChainID.evm(self.chain_manager.chain_id)

Expand All @@ -257,8 +276,7 @@ def address(self) -> str:
if address := self._address:
return address

chain_id = self._chain_id or ChainID.evm(self.chain_manager.chain_id)
return self.get_address(chain_id=chain_id)
return self.get_address(chain_id=self.chain_id)

@property
def tplus_contracts_project(self) -> "Project":
Expand All @@ -277,15 +295,16 @@ def contract(self) -> "ContractInstance":
try:
return self.get_contract()
except ContractNotExists:
if self.chain_manager.provider.network.is_local:
if self.is_local_network and not self._attempted_deploy_dev:
# If simulating, deploy it now.
instance = self.deploy_dev()
self._address = instance.address
self._deployments[self.chain_manager.chain_id] = instance
return instance
return self.deploy_dev_and_set_deployment()

raise # This error.

@property
def is_local_network(self) -> bool:
return self.chain_manager.provider.network.is_local

@property
def default_deployer(self) -> AccountAPI:
if deployer := self._default_deployer:
Expand Down Expand Up @@ -322,13 +341,15 @@ def get_contract(self, chain_id: ChainID | None = None) -> "ContractInstance":
Returns:
ContractInstance
"""
chain_id = chain_id or self._chain_id or ChainID.evm(self.chain_manager.chain_id)
chain_id = chain_id or self.chain_id
if chain_id in self._deployments:
# Get previously cached instance.
return self._deployments[chain_id]

address = self.get_address(chain_id=chain_id)
contract_container = self._contract_container.at(address)
contract_container = self._contract_container.at(
address, detect_proxy=False, fetch_from_explorer=False
)

# Cache for next time.
self._deployments[chain_id] = contract_container
Expand All @@ -339,11 +360,18 @@ def get_address(self, chain_id: ChainID | None = None) -> str:
if self._address and self._chain_id and chain_id == self._chain_id:
return self._address

chain_id = chain_id or self._chain_id or ChainID.evm(self.chain_manager.chain_id)
chain_id = chain_id or self.chain_id
try:
return TPLUS_DEPLOYMENTS[chain_id][self.name]
except KeyError:
raise ContractNotExists(f"{self.name} not deployed on chain '{chain_id}'.")
except KeyError as err:
if self.is_local_network and not self._attempted_deploy_dev:
try:
return self.deploy_dev_and_set_deployment().address
except Exception:
# Raise `ContractNotExists` below.
pass

raise ContractNotExists(f"{self.name} not deployed on chain '{chain_id}'.") from err


class Registry(TPlusContract):
Expand Down Expand Up @@ -383,17 +411,25 @@ def set_asset(
self,
index: int,
asset_address: HexBytes32 | AddressType,
chain_id: int,
chain_id: ChainID,
max_deposit: int,
max_1hr_deposits: int,
min_weight: int,
sender=None,
) -> None:
if isinstance(asset_address, str) and len(asset_address) <= 42:
# Given EVM style address. Store as right-padded address.
asset_address = to_bytes32(asset_address, pad="r")

return self.contract.setAssetData(
index, (asset_address, chain_id, max_deposit), sender=sender
)
data = {
"index": index,
"assetAddress": asset_address,
"chainId": {"routingId": chain_id.routing_id, "vmId": chain_id.vm_id},
"maxDeposits": max_deposit,
"max1hrDeposits": max_1hr_deposits,
"minWeight": min_weight,
}
return self.contract.setAssetData(data, sender=sender)


class DepositVault(TPlusContract):
Expand Down Expand Up @@ -435,6 +471,24 @@ def add_settler_executor(
) -> "ReceiptAPI":
return self.addSettlerExecutor(settler, executor, **kwargs)

def get_settlement_count(self, user: "UserPublicKey | User", account_index: int) -> int:
if not isinstance(user, UserPublicKey):
user = user.public_key

return self.contract.settlementCounts(user, account_index)

def get_deposit_count(self, user: "UserPublicKey | User", account_index: int) -> int:
if not isinstance(user, UserPublicKey):
user = user.public_key

return self.contract.depositCounts(user, account_index)

def get_withdrawal_count(self, user: "UserPublicKey | User", account_index: int) -> int:
if not isinstance(user, UserPublicKey):
user = user.public_key

return self.contract.withdrawalCounts(user, account_index)

def deposit(
self,
user: UserPublicKey,
Expand All @@ -455,18 +509,19 @@ def execute_atomic_settlement(
self,
settlement: dict,
user: UserPublicKey,
sub_account: int,
expiry: int,
data: HexBytes,
signature: HexBytes,
**tx_kwargs,
) -> "ReceiptAPI":
try:
return self.contract.executeAtomicSettlement(
settlement, user, expiry, data, signature, **tx_kwargs
settlement, user, sub_account, expiry, data, signature, **tx_kwargs
)
except Exception as err:
err_id = getattr(err, "message", "")
if erc20_err_name := _decode_erc20_error(err.message):
if erc20_err_name := _decode_erc20_error(getattr(err, "message", f"{err}")):
raise ContractLogicError(erc20_err_name) from err

elif err_id == "0x203d82d8":
Expand Down Expand Up @@ -500,7 +555,9 @@ def deploy_dev(cls, sender: AccountAPI | None = None, **kwargs) -> TPlusContract
"""
Deploy and set up a development vault.
"""
credman = kwargs.get("credential_manager") or ZERO_ADDRESS
if not (credman := kwargs.get("credential_manager")):
credman = credential_manager.address

sender = sender or cls.account_manager.test_accounts[0]
contract = cast(DepositVault, cls.deploy(sender, credman, sender=sender))

Expand Down Expand Up @@ -553,7 +610,10 @@ def deploy_dev(cls, **kwargs) -> "ReceiptAPI":
owner = kwargs.get("sender") or get_dev_default_owner()
operators = kwargs.get("operators", [owner.address])
threshold = kwargs.get("quorum_threshold") or len(operators)
registry_address = kwargs.get("registry") or ZERO_ADDRESS

if not (registry_address := kwargs.get("registry")):
registry_address = registry.address

measurements = kwargs.get("measurements") or []
automata_verifier = kwargs.get("automata_verifier") or ZERO_ADDRESS

Expand Down
Loading