diff --git a/pycardano/plutus.py b/pycardano/plutus.py index 90a83ca2..435c8f4f 100644 --- a/pycardano/plutus.py +++ b/pycardano/plutus.py @@ -1064,25 +1064,52 @@ def plutus_script_hash(script: Union[NativeScript, PlutusScript]) -> ScriptHash: return script_hash(script) -class PlutusScript(bytes): +class PlutusScript(CBORSerializable, bytes): + """ + Plutus script class. + + This class is a base class for all Plutus script versions. + + Example - Load a Plutus script from `test/resources/scriptV2.plutus `_ and get its address: # noqa: E501 + + + >>> from pycardano import Address, Network + >>> script = PlutusV2Script.load("test/resources/scriptV2.plutus") + >>> Address(plutus_script_hash(script), network=Network.TESTNET).encode() + 'addr_test1wrmz3pjz4dmfxj0fc0a0eyw69tp6h7mpndzf9g3kttq9cqqqw47ym' + """ + @property def version(self) -> int: raise NotImplementedError("") + def to_shallow_primitive(self) -> bytes: + return bytes(self) + + @classmethod + def from_primitive( + cls: Type[PlutusScript], value: Any, type_args: Optional[tuple] = None + ) -> PlutusScript: + if not isinstance(value, (bytes, bytearray)): + raise DeserializeException(f"Expect bytes, got {type(value)} instead.") + return cls(value) + @classmethod def from_version(cls, version: int, script_data: bytes) -> "PlutusScript": - if version == 1: - return PlutusV1Script(script_data) - elif version == 2: - return PlutusV2Script(script_data) - elif version == 3: - return PlutusV3Script(script_data) - else: + class_name = f"PlutusV{version}Script" + script_class = globals().get(class_name) + + if script_class is None: raise ValueError(f"No Plutus script class found for version {version}") + return script_class(script_data) + def get_script_hash_prefix(self) -> bytes: raise NotImplementedError("") + def __repr__(self): + return f"{self.__class__.__name__}({self.hex()})" + class PlutusV1Script(PlutusScript): def get_script_hash_prefix(self) -> bytes: diff --git a/pycardano/serialization.py b/pycardano/serialization.py index a82e5455..63e13586 100644 --- a/pycardano/serialization.py +++ b/pycardano/serialization.py @@ -720,14 +720,6 @@ def _restore_typed_primitive( if not isinstance(v, bytes): raise DeserializeException(f"Expected type bytes but got {type(v)}") return ByteString(v) - elif isclass(t) and t.__name__ in [ - "PlutusV1Script", - "PlutusV2Script", - "PlutusV3Script", - ]: - if not isinstance(v, bytes): - raise DeserializeException(f"Expected type bytes but got {type(v)}") - return t(v) elif hasattr(t, "__origin__") and (t.__origin__ is dict): t_args = t.__args__ if len(t_args) != 2: diff --git a/test/pycardano/test_plutus.py b/test/pycardano/test_plutus.py index 26a5ac08..825f3464 100644 --- a/test/pycardano/test_plutus.py +++ b/test/pycardano/test_plutus.py @@ -9,13 +9,14 @@ import pytest from cbor2 import CBORTag -from pycardano import TransactionWitnessSet +from pycardano import Address, Network, TransactionWitnessSet from pycardano.exception import DeserializeException from pycardano.plutus import ( COST_MODELS, Datum, ExecutionUnits, PlutusData, + PlutusV2Script, RawPlutusData, Redeemer, RedeemerKey, @@ -593,3 +594,15 @@ def test_empty_map_deser(): serialized = witness.to_primitive() deserialized = TransactionWitnessSet.from_primitive(serialized) assert deserialized.redeemer == empty_map + + +def test_load_plutus_script(): + script = PlutusV2Script.load("test/resources/scriptV2.plutus") + assert ( + script.to_cbor_hex() + == "581a581801000022232632498cd5ce2481064255524e542100120011" + ) + assert ( + Address(plutus_script_hash(script), network=Network.TESTNET).encode() + == "addr_test1wrmz3pjz4dmfxj0fc0a0eyw69tp6h7mpndzf9g3kttq9cqqqw47ym" + ) diff --git a/test/resources/scriptV2.plutus b/test/resources/scriptV2.plutus new file mode 100644 index 00000000..25e7b359 --- /dev/null +++ b/test/resources/scriptV2.plutus @@ -0,0 +1,5 @@ +{ + "type": "PlutusScriptV2", + "description": "", + "cborHex": "581a581801000022232632498cd5ce2481064255524e542100120011" +}