From 7c00bbcf84678c6b1beb4cbb8357d2f808298433 Mon Sep 17 00:00:00 2001 From: chao-peng-story Date: Thu, 5 Feb 2026 18:27:14 +0800 Subject: [PATCH] feat: add is_registered and get_balance public methods - Expose is_registered as public method in IPAsset class with input validation - Add get_balance(address) method to StoryClient for querying any address balance - Update get_wallet_balance to reuse get_balance method - Add proper input validation with clear error messages --- .../resources/IPAsset.py | 30 ++++++++++++------- src/story_protocol_python_sdk/story_client.py | 23 ++++++++++++-- 2 files changed, 39 insertions(+), 14 deletions(-) diff --git a/src/story_protocol_python_sdk/resources/IPAsset.py b/src/story_protocol_python_sdk/resources/IPAsset.py index 2b0e793..58f4237 100644 --- a/src/story_protocol_python_sdk/resources/IPAsset.py +++ b/src/story_protocol_python_sdk/resources/IPAsset.py @@ -199,7 +199,7 @@ def register( """ try: ip_id = self._get_ip_id(nft_contract, token_id) - if self._is_registered(ip_id): + if self.is_registered(ip_id): return {"tx_hash": None, "ip_id": ip_id} req_object: dict = { @@ -316,7 +316,7 @@ def register_derivative( :return dict: A dictionary with the transaction hash """ try: - if not self._is_registered(child_ip_id): + if not self.is_registered(child_ip_id): raise ValueError( f"The child IP with id {child_ip_id} is not registered." ) @@ -378,7 +378,7 @@ def register_derivative_with_license_tokens( validate_max_rts(max_rts) # Validate child IP registration - if not self._is_registered(child_ip_id): + if not self.is_registered(child_ip_id): raise ValueError( f"The child IP with id {child_ip_id} is not registered." ) @@ -757,7 +757,7 @@ def register_ip_and_attach_pil_terms( """ try: ip_id = self._get_ip_id(nft_contract, token_id) - if self._is_registered(ip_id): + if self.is_registered(ip_id): raise ValueError( f"The NFT with id {token_id} is already registered as IP." ) @@ -872,7 +872,7 @@ def register_derivative_ip( """ try: ip_id = self._get_ip_id(nft_contract, token_id) - if self._is_registered(ip_id): + if self.is_registered(ip_id): raise ValueError( f"The NFT with id {token_id} is already registered as IP." ) @@ -1061,7 +1061,7 @@ def register_ip_and_make_derivative_with_license_tokens( """ try: ip_id = self._get_ip_id(nft_contract, token_id) - if self._is_registered(ip_id): + if self.is_registered(ip_id): raise ValueError( f"The NFT with id {token_id} is already registered as IP." ) @@ -1290,7 +1290,7 @@ def register_derivative_ip_and_attach_pil_terms_and_distribute_royalty_tokens( try: nft_contract = validate_address(nft_contract) ip_id = self._get_ip_id(nft_contract, token_id) - if self._is_registered(ip_id): + if self.is_registered(ip_id): raise ValueError( f"The NFT with id {token_id} is already registered as IP." ) @@ -1397,7 +1397,7 @@ def register_ip_and_attach_pil_terms_and_distribute_royalty_tokens( try: nft_contract = validate_address(nft_contract) ip_id = self._get_ip_id(nft_contract, token_id) - if self._is_registered(ip_id): + if self.is_registered(ip_id): raise ValueError( f"The NFT with id {token_id} is already registered as IP." ) @@ -1500,7 +1500,7 @@ def register_pil_terms_and_attach( :return RegisterPILTermsAndAttachResponse: Dictionary with the tx hash and license terms IDs. """ try: - if not self._is_registered(ip_id): + if not self.is_registered(ip_id): raise ValueError(f"The IP with id {ip_id} is not registered.") calculated_deadline = self.sign_util.get_deadline(deadline=deadline) ip_account_impl_client = IPAccountImplClient(self.web3, ip_id) @@ -2009,7 +2009,7 @@ def _validate_derivative_data(self, derivative_data: dict) -> dict: for parent_id, terms_id in zip( internal_data["parentIpIds"], internal_data["licenseTermsIds"] ): - if not self._is_registered(parent_id): + if not self.is_registered(parent_id): raise ValueError( f"The parent IP with id {parent_id} is not registered." ) @@ -2134,13 +2134,21 @@ def _get_ip_id(self, token_contract: str, token_id: int) -> str: self.chain_id, token_contract, token_id ) - def _is_registered(self, ip_id: str) -> bool: + def is_registered(self, ip_id: str) -> bool: """ Check if an IP is registered. :param ip_id str: The IP ID to check. :return bool: True if registered, False otherwise. + :raises ValueError: If the ip_id is empty or has invalid format. """ + if not ip_id: + raise ValueError("is_registered: ip_id is required") + + if not self.web3.is_address(ip_id): + raise ValueError(f"is_registered: invalid IP ID address format: {ip_id}") + + ip_id = self.web3.to_checksum_address(ip_id) return self.ip_asset_registry_client.isRegistered(ip_id) def _parse_tx_ip_registered_event(self, tx_receipt: dict) -> list[RegisteredIP]: diff --git a/src/story_protocol_python_sdk/story_client.py b/src/story_protocol_python_sdk/story_client.py index 51ea6b5..655ac4b 100644 --- a/src/story_protocol_python_sdk/story_client.py +++ b/src/story_protocol_python_sdk/story_client.py @@ -155,14 +155,31 @@ def Group(self) -> Group: self._group = Group(self.web3, self.account, self.chain_id) return self._group + def get_balance(self, address: str) -> int: + """ + Get the native token (IP) balance of the specified address. + + :param address str: The address to query the balance for. + :return int: The native token balance of the specified address in wei. + :raises ValueError: If the address is invalid. + """ + if not address: + raise ValueError("Address must be provided") + + if not self.web3.is_address(address): + raise ValueError(f"Invalid address format: {address}") + + checksum_address = self.web3.to_checksum_address(address) + return self.web3.eth.get_balance(checksum_address) + def get_wallet_balance(self) -> int: """ - Get the WIP token balance of the current wallet. + Get the native token (IP) balance of the current wallet. - :return int: The WIP token balance of the current wallet. + :return int: The native token balance of the current wallet in wei. :raises ValueError: If no account is found. """ if not self.account or not hasattr(self.account, "address"): raise ValueError("No account found in wallet") - return self.web3.eth.get_balance(self.account.address) + return self.get_balance(self.account.address)