From e24cd92074cf469a786c12eaea5428684bd8991c Mon Sep 17 00:00:00 2001 From: FP Date: Fri, 25 Aug 2023 12:22:02 -0700 Subject: [PATCH 1/5] feat: optimize internal calls --- brownie-config.yaml | 2 +- contracts/VM.sol | 32 +++++++++++++++------ contracts/test/TestableVMWithMath.sol | 41 +++++++++++++++++++++++++++ tests/conftest.py | 4 +++ tests/test_weiroll_local.py | 20 +++++++++++++ 5 files changed, 90 insertions(+), 9 deletions(-) create mode 100644 contracts/test/TestableVMWithMath.sol create mode 100644 tests/test_weiroll_local.py diff --git a/brownie-config.yaml b/brownie-config.yaml index 17e0433..1666623 100644 --- a/brownie-config.yaml +++ b/brownie-config.yaml @@ -13,7 +13,7 @@ compiler: version: 0.8.11 remappings: - "@openzeppelin=OpenZeppelin/openzeppelin-contracts@4.1.0" - + viaIR: true optimizer: details: yul: true diff --git a/contracts/VM.sol b/contracts/VM.sol index f4558a0..d73de20 100644 --- a/contracts/VM.sol +++ b/contracts/VM.sol @@ -29,6 +29,8 @@ abstract contract VM { self = address(this); } + function dispatch(bytes memory inputs) internal virtual returns (bool success, bytes memory ret) {} + function _execute(bytes32[] calldata commands, bytes[] memory state) internal returns (bytes[] memory) { @@ -61,14 +63,28 @@ abstract contract VM { ) ); } else if (flags & FLAG_CT_MASK == FLAG_CT_CALL) { - (success, outdata) = address(uint160(uint256(command))).call( // target - // inputs - state.buildInputs( - //selector - bytes4(command), - indices - ) - ); + address _target = address(uint160(uint256(command))); + bytes memory inputs = state.buildInputs( + //selector + bytes4(command), + indices + ); + success = false; + + if (_target == address(this)) { + (success, outdata) = dispatch(inputs); + } + + if (!success) { + (success, outdata) = _target.call( // target + // inputs + state.buildInputs( + //selector + bytes4(command), + indices + ) + ); + } } else if (flags & FLAG_CT_MASK == FLAG_CT_STATICCALL) { (success, outdata) = address(uint160(uint256(command))).staticcall( // target // inputs diff --git a/contracts/test/TestableVMWithMath.sol b/contracts/test/TestableVMWithMath.sol new file mode 100644 index 0000000..2928968 --- /dev/null +++ b/contracts/test/TestableVMWithMath.sol @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.11; + +import "../VM.sol"; + +contract TestableVMWithMath is VM { + function execute(bytes32[] calldata commands, bytes[] memory state) + public + payable + returns (bytes[] memory) + { + return _execute(commands, state); + } + + function sum(uint256 a, uint256 b) external pure returns (uint256) { + return a + b; + } + + + function dispatch(bytes memory inputs) + internal + override + returns (bool _success, bytes memory _ret) + { + bytes4 _selector = bytes4(bytes32(inputs)); + if (this.sum.selector == _selector) { + uint256 a; + uint256 b; + assembly { + a := mload(add(inputs, 36)) + b := mload(add(inputs, 68)) + } + uint256 res = this.sum(a, b); + _ret = new bytes(32); + assembly { + mstore(add(_ret, 32), res) + } + return (true, _ret); + } + } +} diff --git a/tests/conftest.py b/tests/conftest.py index df09379..8623212 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,6 +23,10 @@ def math(alice, Math): math_brownie = alice.deploy(Math) yield WeirollContract.createLibrary(math_brownie) +@pytest.fixture(scope="module") +def weiroll_vm_with_math(alice, TestableVMWithMath): + vm = alice.deploy(TestableVMWithMath) + yield vm @pytest.fixture(scope="module") def testContract(alice, TestContract): diff --git a/tests/test_weiroll_local.py b/tests/test_weiroll_local.py new file mode 100644 index 0000000..bdd780b --- /dev/null +++ b/tests/test_weiroll_local.py @@ -0,0 +1,20 @@ +from brownie import Contract, accounts, Wei, chain, TestableVM +from weiroll import WeirollContract, WeirollPlanner + + +def test_vm_with_math(weiroll_vm_with_math): + weiroll_vm = weiroll_vm_with_math + whale = accounts.at("0x57757E3D981446D585Af0D9Ae4d7DF6D64647806", force=True) + + planner = WeirollPlanner(weiroll_vm) + sum = planner.call(weiroll_vm, "sum", 1, 2) + sum_2 = planner.call(weiroll_vm, "sum", sum, 3) + sum_2 = planner.call(weiroll_vm, "sum", 3, sum_2) + + cmds, state = planner.plan() + weiroll_tx = weiroll_vm.execute( + cmds, state, {"from": whale, "gas_limit": 8_000_000, "gas_price": 0} + ) + + print(weiroll_tx.return_value) + #assert False From 08c19f17a292c9f55d90c3f0f09ddcb5d2799f4f Mon Sep 17 00:00:00 2001 From: FP Date: Fri, 25 Aug 2023 13:08:10 -0700 Subject: [PATCH 2/5] fix: make sum public for more savings --- contracts/test/TestableVMWithMath.sol | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/contracts/test/TestableVMWithMath.sol b/contracts/test/TestableVMWithMath.sol index 2928968..1e1dbb9 100644 --- a/contracts/test/TestableVMWithMath.sol +++ b/contracts/test/TestableVMWithMath.sol @@ -12,7 +12,7 @@ contract TestableVMWithMath is VM { return _execute(commands, state); } - function sum(uint256 a, uint256 b) external pure returns (uint256) { + function sum(uint256 a, uint256 b) public pure returns (uint256) { return a + b; } @@ -30,7 +30,7 @@ contract TestableVMWithMath is VM { a := mload(add(inputs, 36)) b := mload(add(inputs, 68)) } - uint256 res = this.sum(a, b); + uint256 res = sum(a, b); _ret = new bytes(32); assembly { mstore(add(_ret, 32), res) From a33b67cc18afd7ecabeccdc593027c4178a1b8b3 Mon Sep 17 00:00:00 2001 From: FP Date: Mon, 28 Aug 2023 10:33:43 -0700 Subject: [PATCH 3/5] feat: local dispatch flag --- contracts/VM.sol | 29 ++++++++++---------- contracts/test/TestableVMWithMath.sol | 38 +++++++++++++++++++++++++++ tests/test_weiroll_local.py | 29 ++++++++++++++++---- weiroll.py | 15 ++++++++++- 4 files changed, 91 insertions(+), 20 deletions(-) diff --git a/contracts/VM.sol b/contracts/VM.sol index d73de20..a849180 100644 --- a/contracts/VM.sol +++ b/contracts/VM.sol @@ -15,6 +15,8 @@ abstract contract VM { uint256 constant FLAG_EXTENDED_COMMAND = 0x40; uint256 constant FLAG_TUPLE_RETURN = 0x80; + uint256 constant FLAG_LOCAL_DISPATCH = 0x20; // custom local dispatch flag + uint256 constant SHORT_COMMAND_FILL = 0x000000000000FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF; address immutable self; @@ -63,20 +65,19 @@ abstract contract VM { ) ); } else if (flags & FLAG_CT_MASK == FLAG_CT_CALL) { - address _target = address(uint160(uint256(command))); - bytes memory inputs = state.buildInputs( - //selector - bytes4(command), - indices - ); - success = false; - - if (_target == address(this)) { - (success, outdata) = dispatch(inputs); - } - - if (!success) { - (success, outdata) = _target.call( // target + if (flags & FLAG_LOCAL_DISPATCH != 0) { + address _target = address(uint160(uint256(command))); + require(_target == address(this), "_execute: local dispatch must target VM."); + (success, outdata) = dispatch( + // inputs + state.buildInputs( + //selector + bytes4(command), + indices + ) + ); + } else { + (success, outdata) = address(uint160(uint256(command))).call( // target // inputs state.buildInputs( //selector diff --git a/contracts/test/TestableVMWithMath.sol b/contracts/test/TestableVMWithMath.sol index 1e1dbb9..58c33fc 100644 --- a/contracts/test/TestableVMWithMath.sol +++ b/contracts/test/TestableVMWithMath.sol @@ -16,6 +16,14 @@ contract TestableVMWithMath is VM { return a + b; } + function sum3(uint256 a, uint256 b, uint256 c) public pure returns (uint256) { + return a + b; + } + + function sub(uint256 a, uint256 b) public pure returns (uint256) { + return a - b; + } + function dispatch(bytes memory inputs) internal @@ -37,5 +45,35 @@ contract TestableVMWithMath is VM { } return (true, _ret); } + if (this.sub.selector == _selector) { + uint256 a; + uint256 b; + assembly { + a := mload(add(inputs, 36)) + b := mload(add(inputs, 68)) + } + uint256 res = sub(a, b); + _ret = new bytes(32); + assembly { + mstore(add(_ret, 32), res) + } + return (true, _ret); + } + if (this.sum3.selector == _selector) { + uint256 a; + uint256 b; + uint256 c; + assembly { + a := mload(add(inputs, 36)) + b := mload(add(inputs, 68)) + c := mload(add(inputs, 100)) + } + uint256 res = sum3(a, b, c); + _ret = new bytes(32); + assembly { + mstore(add(_ret, 32), res) + } + return (true, _ret); + } } } diff --git a/tests/test_weiroll_local.py b/tests/test_weiroll_local.py index bdd780b..9242a1f 100644 --- a/tests/test_weiroll_local.py +++ b/tests/test_weiroll_local.py @@ -7,14 +7,33 @@ def test_vm_with_math(weiroll_vm_with_math): whale = accounts.at("0x57757E3D981446D585Af0D9Ae4d7DF6D64647806", force=True) planner = WeirollPlanner(weiroll_vm) - sum = planner.call(weiroll_vm, "sum", 1, 2) - sum_2 = planner.call(weiroll_vm, "sum", sum, 3) - sum_2 = planner.call(weiroll_vm, "sum", 3, sum_2) + w_math = WeirollContract.createContract(weiroll_vm) + sum = planner.add(w_math.sum(1, 2)) + sum_2 = planner.add(w_math.sum(3, sum)) + sum_3 = planner.add(w_math.sum3(3, sum_2, 4)) + planner.add(w_math.sub(sum_3, 3)) cmds, state = planner.plan() weiroll_tx = weiroll_vm.execute( cmds, state, {"from": whale, "gas_limit": 8_000_000, "gas_price": 0} ) - print(weiroll_tx.return_value) - #assert False + assert False + +def test_vm_with_math_local_dispatch(weiroll_vm_with_math): + weiroll_vm = weiroll_vm_with_math + whale = accounts.at("0x57757E3D981446D585Af0D9Ae4d7DF6D64647806", force=True) + + planner = WeirollPlanner(weiroll_vm) + w_math = WeirollContract.createContract(weiroll_vm) + sum = planner.add(w_math.sum(1, 2).localDispatch()) + sum_2 = planner.add(w_math.sum(3, sum).localDispatch()) + sum_3 = planner.add(w_math.sum3(3, sum_2, 4).localDispatch()) + planner.add(w_math.sub(sum_3, 3).localDispatch()) + + cmds, state = planner.plan() + weiroll_tx = weiroll_vm.execute( + cmds, state, {"from": whale, "gas_limit": 8_000_000, "gas_price": 0} + ) + + assert False diff --git a/weiroll.py b/weiroll.py index 4a609b8..62780f2 100644 --- a/weiroll.py +++ b/weiroll.py @@ -167,7 +167,8 @@ class CommandFlags(IntFlag): EXTENDED_COMMAND = 0x40 # Specifies that the return value of this call should be wrapped in a `bytes`. Internal use only. TUPLE_RETURN = 0x80 - + # Specifies to use fast local dispatcher. Custom flag. + LOCAL_DISPATCH = 0x20 class FunctionCall: def __init__(self, contract, flags: CommandFlags, fragment: FunctionFragment, args, callvalue=0): @@ -220,6 +221,18 @@ def staticcall(self): self.callvalue, ) + def localDispatch(self): + """ + Returns a new [[FunctionCall]] with local dispatch specified + """ + return self.__class__( + self.contract, + self.flags | CommandFlags.LOCAL_DISPATCH, + self.fragment, + self.args, + self.callvalue, + ) + def isDynamicType(param) -> bool: return eth_abi.grammar.parse(param).is_dynamic From 97d42df06f4e42705635ccfd9403726406d01c7e Mon Sep 17 00:00:00 2001 From: FP Date: Tue, 12 Sep 2023 11:22:39 -0700 Subject: [PATCH 4/5] feat: auto_local_dispatch support --- pyproject.toml | 3 +++ weiroll.py | 56 +++++++++++++++++++++++++------------------------- 2 files changed, 31 insertions(+), 28 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b9968f9..6271728 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,3 +17,6 @@ eth-brownie = "^v1.17" [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" + +[tool.black] +line-length = 120 diff --git a/weiroll.py b/weiroll.py index 62780f2..35a44d3 100644 --- a/weiroll.py +++ b/weiroll.py @@ -11,7 +11,7 @@ from brownie.network.contract import OverloadedMethod from hexbytes import HexBytes -MAX_UINT256 = 2**256-1 +MAX_UINT256 = 2 ** 256 - 1 # TODO: real types? Value = namedtuple("Value", "param") @@ -25,7 +25,6 @@ def simple_type_strings(inputs) -> tuple[Optional[list[str]], Optional[list[int] related: https://github.com/weiroll/weiroll.js/pull/34 """ - if not inputs: return None, None @@ -170,8 +169,16 @@ class CommandFlags(IntFlag): # Specifies to use fast local dispatcher. Custom flag. LOCAL_DISPATCH = 0x20 + class FunctionCall: - def __init__(self, contract, flags: CommandFlags, fragment: FunctionFragment, args, callvalue=0): + def __init__( + self, + contract, + flags: CommandFlags, + fragment: FunctionFragment, + args, + callvalue=0, + ): self.contract = contract self.flags = flags self.fragment = fragment @@ -223,7 +230,7 @@ def staticcall(self): def localDispatch(self): """ - Returns a new [[FunctionCall]] with local dispatch specified + Returns a new [[FunctionCall]] with local dispatch specified """ return self.__class__( self.contract, @@ -334,7 +341,6 @@ def _overload(*args, fn_name=name): self.functionsBySignature[signature] = plan_fn - @classmethod @cache def createContract( @@ -419,31 +425,11 @@ def padArray(a, length, padValue) -> list: class WeirollPlanner: - def __init__(self, clone): + def __init__(self, address: str | None = None, auto_local_dispatch: bool = False): self.state = StateValue() self.commands: list[Command] = [] - self.unlimited_approvals = set() - - self.clone = clone - - def approve(self, token: brownie.Contract, spender: str, wei_needed, approve_wei=None) -> Optional[ReturnValue]: - key = (token, self.clone, spender) - - if approve_wei is None: - approve_wei = MAX_UINT256 - - if key in self.unlimited_approvals and approve_wei != 0: - # we already planned an infinite approval for this token (and we aren't trying to set the approval to 0) - return - - # check current allowance - if token.allowance(self.clone, spender) >= wei_needed: - return - - if approve_wei == MAX_UINT256: - self.unlimited_approvals.add(key) - - return self.call(token, "approve", spender, approve_wei) + self.address: str | None = address + self.auto_local_dispatch = auto_local_dispatch def call(self, brownieContract: brownie.Contract, func_name, *args): """func_name can be just the name, or it can be the full signature. @@ -493,6 +479,20 @@ def add(self, call: FunctionCall) -> Optional[ReturnValue]: * @param call The [[FunctionCall]] to add to the planner * @returns An object representing the return value of the call, or null if it does not return a value. """ + + # Use auto local dispatch if: + # - its enabled + # - the planner has a vm address set + # - the FunctionCall is a Call command + # - the targeted contract is the weiroll vm + if ( + self.auto_local_dispatch + and self.address is not None + and (call.flags & CommandFlags.CALLTYPE_MASK) == CommandFlags.CALL + and call.contract == self.address + ): + call = call.localDispatch() + command = Command(call, CommandType.CALL) self.commands.append(command) From 63e284711cb2389309ab1695fcb0d83ad7af2361 Mon Sep 17 00:00:00 2001 From: FP Date: Tue, 12 Sep 2023 12:00:06 -0700 Subject: [PATCH 5/5] chore: test auto_local_dispatch --- tests/test_weiroll.py | 70 +++++++++---------------------------- tests/test_weiroll_local.py | 27 ++++++++++++-- weiroll.py | 4 +-- 3 files changed, 43 insertions(+), 58 deletions(-) diff --git a/tests/test_weiroll.py b/tests/test_weiroll.py index 8e178e0..2005536 100644 --- a/tests/test_weiroll.py +++ b/tests/test_weiroll.py @@ -123,9 +123,7 @@ def test_weiroll_takes_dynamic_arguments(alice, strings): commands, state = planner.plan() assert len(commands) == 1 - assert commands[0] == weiroll.hexConcat( - "0x367bbd780080ffffffffffff", strings.address - ) + assert commands[0] == weiroll.hexConcat("0x367bbd780080ffffffffffff", strings.address) print(state) assert len(state) == 1 @@ -138,9 +136,7 @@ def test_weiroll_returns_dynamic_arguments(alice, strings): commands, state = planner.plan() assert len(commands) == 1 - assert commands[0] == weiroll.hexConcat( - "0xd824ccf3008081ffffffffff", strings.address - ) + assert commands[0] == weiroll.hexConcat("0xd824ccf3008081ffffffffff", strings.address) assert len(state) == 2 assert state[0] == eth_abi.encode_single("string", "Hello, ") @@ -154,12 +150,8 @@ def test_weiroll_takes_dynamic_argument_from_a_return_value(alice, strings): commands, state = planner.plan() assert len(commands) == 2 - assert commands[0] == weiroll.hexConcat( - "0xd824ccf3008081ffffffff81", strings.address - ) - assert commands[1] == weiroll.hexConcat( - "0x367bbd780081ffffffffffff", strings.address - ) + assert commands[0] == weiroll.hexConcat("0xd824ccf3008081ffffffff81", strings.address) + assert commands[1] == weiroll.hexConcat("0x367bbd780081ffffffffffff", strings.address) assert len(state) == 2 assert state[0] == eth_abi.encode_single("string", "Hello, ") @@ -179,9 +171,7 @@ def test_weiroll_func_takes_and_replaces_current_state(alice, testContract): commands, state = planner.plan() assert len(commands) == 1 - assert commands[0] == weiroll.hexConcat( - "0x08f389c800fefffffffffffe", testContract.address - ) + assert commands[0] == weiroll.hexConcat("0x08f389c800fefffffffffffe", testContract.address) assert len(state) == 0 @@ -194,9 +184,7 @@ def test_weiroll_supports_subplan(alice, math, subplanContract): planner.addSubplan(subplanContract.execute(subplanner, subplanner.state)) commands, state = planner.plan() - assert commands == [ - weiroll.hexConcat("0xde792d5f0082fefffffffffe", subplanContract.address) - ] + assert commands == [weiroll.hexConcat("0xde792d5f0082fefffffffffe", subplanContract.address)] assert len(state) == 3 assert state[0] == eth_abi.encode_single("uint", 1) @@ -216,9 +204,7 @@ def test_weiroll_subplan_allows_return_in_parent_scope(alice, math, subplanContr commands, _ = planner.plan() assert len(commands) == 2 # Invoke subplanner - assert commands[0] == weiroll.hexConcat( - "0xde792d5f0083fefffffffffe", subplanContract.address - ) + assert commands[0] == weiroll.hexConcat("0xde792d5f0083fefffffffffe", subplanContract.address) # sum + 3 assert commands[1] == weiroll.hexConcat("0x771602f7000102ffffffffff", math.address) @@ -237,12 +223,8 @@ def test_weiroll_return_values_across_scopes(alice, math, subplanContract): commands, state = planner.plan() assert len(commands) == 2 - assert commands[0] == weiroll.hexConcat( - "0xde792d5f0083fefffffffffe", subplanContract.address - ) - assert commands[1] == weiroll.hexConcat( - "0xde792d5f0084fefffffffffe", subplanContract.address - ) + assert commands[0] == weiroll.hexConcat("0xde792d5f0083fefffffffffe", subplanContract.address) + assert commands[1] == weiroll.hexConcat("0xde792d5f0084fefffffffffe", subplanContract.address) assert len(state) == 5 # TODO: javascript tests were more complex than this @@ -266,43 +248,29 @@ def test_weiroll_add_subplan_needs_args(alice, math, subplanContract): planner = weiroll.WeirollPlanner(alice) - with pytest.raises( - ValueError, match="Subplans must take planner and state arguments" - ): + with pytest.raises(ValueError, match="Subplans must take planner and state arguments"): planner.addSubplan(subplanContract.execute(subplanner, [])) - with pytest.raises( - ValueError, match="Subplans must take planner and state arguments" - ): + with pytest.raises(ValueError, match="Subplans must take planner and state arguments"): planner.addSubplan(subplanContract.execute([], subplanner.state)) -def test_weiroll_doesnt_allow_multiple_subplans_per_call( - alice, math, multiSubplanContract -): +def test_weiroll_doesnt_allow_multiple_subplans_per_call(alice, math, multiSubplanContract): subplanner = weiroll.WeirollPlanner(alice) subplanner.add(math.add(1, 2)) planner = weiroll.WeirollPlanner(alice) with pytest.raises(ValueError, match="Subplans can only take one planner argument"): - planner.addSubplan( - multiSubplanContract.execute(subplanner, subplanner, subplanner.state) - ) + planner.addSubplan(multiSubplanContract.execute(subplanner, subplanner, subplanner.state)) -def test_weiroll_doesnt_allow_state_array_per_call( - alice, math, multiStateSubplanContract -): +def test_weiroll_doesnt_allow_state_array_per_call(alice, math, multiStateSubplanContract): subplanner = weiroll.WeirollPlanner(alice) subplanner.add(math.add(1, 2)) planner = weiroll.WeirollPlanner(alice) with pytest.raises(ValueError, match="Subplans can only take one state argument"): - planner.addSubplan( - multiStateSubplanContract.execute( - subplanner, subplanner.state, subplanner.state - ) - ) + planner.addSubplan(multiStateSubplanContract.execute(subplanner, subplanner.state, subplanner.state)) def test_weiroll_subplan_has_correct_return_type(alice, math, badSubplanContract): @@ -335,9 +303,7 @@ def test_subplans_without_returns(alice, math, readonlySubplanContract): commands, _ = planner.plan() assert len(commands) == 1 - commands[0] == weiroll.hexConcat( - "0xde792d5f0082feffffffffff", readonlySubplanContract.address - ) + commands[0] == weiroll.hexConcat("0xde792d5f0082feffffffffff", readonlySubplanContract.address) def test_read_only_subplans_requirements(alice, math, readonlySubplanContract): @@ -355,9 +321,7 @@ def test_read_only_subplans_requirements(alice, math, readonlySubplanContract): @pytest.mark.xfail(reason="need to write this") def test_plan_with_loop(alice): - target_calldata = ( - "0xc6b6816900000000000000000000000000000000000000000000054b40b1f852bda0" - ) + target_calldata = "0xc6b6816900000000000000000000000000000000000000000000054b40b1f852bda0" """ [ diff --git a/tests/test_weiroll_local.py b/tests/test_weiroll_local.py index 9242a1f..a8e796c 100644 --- a/tests/test_weiroll_local.py +++ b/tests/test_weiroll_local.py @@ -1,7 +1,9 @@ from brownie import Contract, accounts, Wei, chain, TestableVM from weiroll import WeirollContract, WeirollPlanner +import pytest +@pytest.mark.skip() def test_vm_with_math(weiroll_vm_with_math): weiroll_vm = weiroll_vm_with_math whale = accounts.at("0x57757E3D981446D585Af0D9Ae4d7DF6D64647806", force=True) @@ -18,13 +20,13 @@ def test_vm_with_math(weiroll_vm_with_math): cmds, state, {"from": whale, "gas_limit": 8_000_000, "gas_price": 0} ) - assert False +# @pytest.mark.skip() def test_vm_with_math_local_dispatch(weiroll_vm_with_math): weiroll_vm = weiroll_vm_with_math whale = accounts.at("0x57757E3D981446D585Af0D9Ae4d7DF6D64647806", force=True) - planner = WeirollPlanner(weiroll_vm) + planner = WeirollPlanner(weiroll_vm) w_math = WeirollContract.createContract(weiroll_vm) sum = planner.add(w_math.sum(1, 2).localDispatch()) sum_2 = planner.add(w_math.sum(3, sum).localDispatch()) @@ -36,4 +38,23 @@ def test_vm_with_math_local_dispatch(weiroll_vm_with_math): cmds, state, {"from": whale, "gas_limit": 8_000_000, "gas_price": 0} ) - assert False + +# @pytest.mark.skip() +def test_vm_with_math_auto_local_dispatch(weiroll_vm_with_math): + weiroll_vm = weiroll_vm_with_math + whale = accounts.at("0x57757E3D981446D585Af0D9Ae4d7DF6D64647806", force=True) + + planner = WeirollPlanner(weiroll_vm, auto_local_dispatch=True) + w_math = WeirollContract.createContract(weiroll_vm) + sum = planner.add(w_math.sum(1, 2)) + sum_2 = planner.add(w_math.sum(3, sum)) + sum_3 = planner.add(w_math.sum3(3, sum_2, 4)) + planner.add(w_math.sub(sum_3, 3)) + + for command in planner.commands: + assert(command.call.flags >> 5 & 0x1) + + cmds, state = planner.plan() + weiroll_tx = weiroll_vm.execute( + cmds, state, {"from": whale, "gas_limit": 8_000_000, "gas_price": 0} + ) diff --git a/weiroll.py b/weiroll.py index 35a44d3..6a39836 100644 --- a/weiroll.py +++ b/weiroll.py @@ -425,7 +425,7 @@ def padArray(a, length, padValue) -> list: class WeirollPlanner: - def __init__(self, address: str | None = None, auto_local_dispatch: bool = False): + def __init__(self, address: Optional[str] = None, auto_local_dispatch: bool = False): self.state = StateValue() self.commands: list[Command] = [] self.address: str | None = address @@ -489,7 +489,7 @@ def add(self, call: FunctionCall) -> Optional[ReturnValue]: self.auto_local_dispatch and self.address is not None and (call.flags & CommandFlags.CALLTYPE_MASK) == CommandFlags.CALL - and call.contract == self.address + and call.contract.address == self.address ): call = call.localDispatch()