diff --git a/tests/functional/syntax/warnings/test_contract_size_limit_warning.py b/tests/functional/syntax/warnings/test_contract_size_limit_warning.py index 3e27304266..3a7b457b5d 100644 --- a/tests/functional/syntax/warnings/test_contract_size_limit_warning.py +++ b/tests/functional/syntax/warnings/test_contract_size_limit_warning.py @@ -15,9 +15,9 @@ def huge_bytestring(): def test_contract_size_exceeded(huge_bytestring): code = f""" @external -def a() -> bool: +def a() -> Bytes[24577]: q: Bytes[24577] = {huge_bytestring} - return True + return q """ with pytest.warns(vyper.warnings.ContractSizeLimit): vyper.compile_code(code, output_formats=["bytecode_runtime"]) diff --git a/tests/hevm.py b/tests/hevm.py index 50847fc362..c6f56d17c8 100644 --- a/tests/hevm.py +++ b/tests/hevm.py @@ -7,6 +7,7 @@ from vyper.ir.compile_ir import assembly_to_evm from vyper.venom import ( CFGNormalization, + ConcretizeMemLocPass, LowerDloadPass, SimplifyCFGPass, SingleUseExpansion, @@ -70,6 +71,7 @@ def _prep_hevm_venom_ctx(ctx, verbose=False): # requirements for venom_to_assembly LowerDloadPass(ac, fn).run_pass() + ConcretizeMemLocPass(ac, fn).run_pass() SingleUseExpansion(ac, fn).run_pass() CFGNormalization(ac, fn).run_pass() diff --git a/tests/unit/compiler/venom/test_abstract_mem.py b/tests/unit/compiler/venom/test_abstract_mem.py new file mode 100644 index 0000000000..67c52985bf --- /dev/null +++ b/tests/unit/compiler/venom/test_abstract_mem.py @@ -0,0 +1,19 @@ +from vyper.venom.basicblock import IRAbstractMemLoc +from vyper.venom.memory_location import ( + MemoryLocation, + MemoryLocationAbstract, + MemoryLocationSegment, +) + + +def test_abstract_may_overlap(): + op1 = IRAbstractMemLoc(256, offset=0, force_id=0) + op2 = IRAbstractMemLoc(256, offset=128, force_id=0) + loc1 = MemoryLocationAbstract( + op=op1, segment=MemoryLocationSegment(_offset=op1.offset, _size=32) + ) + loc2 = MemoryLocationAbstract( + op=op2, segment=MemoryLocationSegment(_offset=op2.offset, _size=32) + ) + + assert not MemoryLocation.may_overlap(loc1, loc2) diff --git a/tests/unit/compiler/venom/test_common_subexpression_elimination.py b/tests/unit/compiler/venom/test_common_subexpression_elimination.py index 9ad1d6d723..26f7688ac9 100644 --- a/tests/unit/compiler/venom/test_common_subexpression_elimination.py +++ b/tests/unit/compiler/venom/test_common_subexpression_elimination.py @@ -357,6 +357,16 @@ def call(callname: str, i: int, var_name: str): %{var_name}1 = add 1, %{callname}1 """ + def call2(callname: str, i: int, var_name: str): + return f""" + %g{2*i} = gas + %{callname}0 = {callname} %g0, 0, 0, 0, 0, 0, 0 + %{var_name}0 = add 1, %{callname}0 + %g{2*i + 1} = gas + %{callname}1 = {callname} %g0, 0, 0, 0, 0, 0, 0 + %{var_name}1 = add 1, %{callname}1 + """ + pre = f""" main: ; staticcall @@ -366,7 +376,7 @@ def call(callname: str, i: int, var_name: str): {call("delegatecall", 1, "d")} ; call - {call("call", 2, "c")} + {call2("call", 2, "c")} sink %s0, %s1, %d0, %d1, %c0, %c1 """ @@ -470,9 +480,7 @@ def test_cse_immutable_queries(opcode): _check_pre_post(pre, post, hevm=opcode != "codesize") -@pytest.mark.parametrize( - "opcode", ("dloadbytes", "extcodecopy", "codecopy", "returndatacopy", "calldatacopy") -) +@pytest.mark.parametrize("opcode", ("dloadbytes", "codecopy", "returndatacopy", "calldatacopy")) def test_cse_other_mem_ops_elimination(opcode): pre = f""" main: @@ -491,6 +499,24 @@ def test_cse_other_mem_ops_elimination(opcode): _check_pre_post(pre, post) +def test_cse_other_mem_ops_elimination_extcodecopy(): + pre = 
""" + main: + extcodecopy 10, 20, 30, 40 + extcodecopy 10, 20, 30, 40 + stop + """ + + post = """ + main: + extcodecopy 10, 20, 30, 40 + nop + stop + """ + + _check_pre_post(pre, post, hevm=False) + + def test_cse_self_conflicting_effects(): """ Test that expression that have conflict in their own effects diff --git a/tests/unit/compiler/venom/test_concretize_mem.py b/tests/unit/compiler/venom/test_concretize_mem.py new file mode 100644 index 0000000000..cd301ea13e --- /dev/null +++ b/tests/unit/compiler/venom/test_concretize_mem.py @@ -0,0 +1,106 @@ +from tests.venom_utils import PrePostChecker +from vyper.venom.passes import AssignElimination, ConcretizeMemLocPass, Mem2Var + +_check_pre_post = PrePostChecker([ConcretizeMemLocPass], default_hevm=False) +_check_pre_post_mem2var = PrePostChecker([Mem2Var, AssignElimination], default_hevm=False) + + +def test_valid_overlap(): + pre = """ + main: + calldatacopy [3,256], 100, 256 + %1 = mload [3,256] + calldatacopy [4,32], 200, 32 + %2 = mload [4,32] + calldatacopy [3,256], 1000, 256 + %3 = mload [3,256] + sink %1, %2, %3 + """ + post = """ + main: + calldatacopy 64, 100, 256 + %1 = mload 64 + calldatacopy 64, 200, 32 + %2 = mload 64 + calldatacopy 64, 1000, 256 + %3 = mload 64 + sink %1, %2, %3 + """ + + _check_pre_post(pre, post) + + +def test_venom_allocation(): + pre = """ + main: + %ptr = alloca 0, [3,256] + calldatacopy %ptr, 100, 256 + %1 = mload %ptr + sink %1 + """ + + post1 = """ + main: + calldatacopy [3,256], 100, 256 + %1 = mload [3,256] + sink %1 + """ + + post2 = """ + main: + calldatacopy 64, 100, 256 + %1 = mload 64 + sink %1 + """ + + _check_pre_post_mem2var(pre, post1) + _check_pre_post(post1, post2) + + +def test_venom_allocation_branches(): + pre = """ + main: + %ptr1 = alloca 0, [3,256] + %ptr2 = alloca 1, [4,128] + %cond = source + jnz %cond, @then, @else + then: + calldatacopy %ptr1, 100, 256 + %1 = mload %ptr1 + sink %1 + else: + calldatacopy %ptr2, 1000, 64 + %2 = mload %ptr2 + sink %2 + """ + + post1 = """ + main: + %cond = source + jnz %cond, @then, @else + then: + calldatacopy [3,256], 100, 256 + %1 = mload [3,256] + sink %1 + else: + calldatacopy [4,128], 1000, 64 + %2 = mload [4,128] + sink %2 + """ + + post2 = """ + main: + %cond = source + jnz %cond, @then, @else + then: + calldatacopy 64, 100, 256 + %1 = mload 64 + sink %1 + else: + calldatacopy 64, 1000, 64 + %2 = mload 64 + sink %2 + """ + + _check_pre_post_mem2var(pre, post1) + _check_pre_post(post1, post2) diff --git a/tests/unit/compiler/venom/test_dead_store_elimination.py b/tests/unit/compiler/venom/test_dead_store_elimination.py index b5e19ebaab..47ed55c4de 100644 --- a/tests/unit/compiler/venom/test_dead_store_elimination.py +++ b/tests/unit/compiler/venom/test_dead_store_elimination.py @@ -27,7 +27,7 @@ def __init__( self.volatile_locations = volatile_locations def __call__(self, pre: str, post: str, hevm: bool | None = None) -> list[IRPass]: - from vyper.venom.memory_location import MemoryLocation + from vyper.venom.memory_location import MemoryLocationSegment self.pass_objects.clear() @@ -41,7 +41,7 @@ def __call__(self, pre: str, post: str, hevm: bool | None = None) -> list[IRPass mem_ssa = ac.request_analysis(mem_ssa_type_factory(self.addr_space)) for address, size in self.volatile_locations: - volatile_loc = MemoryLocation(offset=address, size=size, is_volatile=True) + volatile_loc = MemoryLocationSegment(_offset=address, _size=size, _is_volatile=True) mem_ssa.mark_location_volatile(volatile_loc) for p in self.passes: @@ -74,25 +74,26 @@ def 
_check_no_change(code, hevm=False): return _check_pre_post(code, code, hevm=hevm) -def test_basic_dead_store(): - pre = """ +@pytest.mark.parametrize("position", [0, "[0,32]"]) +def test_basic_dead_store(position): + pre = f""" _global: %val1 = 42 %val2 = 24 - mstore 0, %val1 ; Dead store - overwritten before read - mstore 0, 10 ; Dead store - overwritten before read - mstore 0, %val2 - %loaded = mload 0 ; Only reads val2 + mstore {position}, %val1 ; Dead store - overwritten before read + mstore {position}, 10 ; Dead store - overwritten before read + mstore {position}, %val2 + %loaded = mload {position} ; Only reads val2 stop """ - post = """ + post = f""" _global: %val1 = 42 %val2 = 24 nop nop - mstore 0, %val2 - %loaded = mload 0 + mstore {position}, %val2 + %loaded = mload {position} stop """ _check_pre_post(pre, post) @@ -117,49 +118,54 @@ def test_basic_not_dead_store(): _check_pre_post(pre, post) -def test_basic_not_dead_store_with_mload(): - pre = """ +@pytest.mark.parametrize("positions", [(0, 32), ("[0,32]", "[1,32]")]) +def test_basic_not_dead_store_with_mload(positions): + a, b = positions + pre = f""" _global: %1 = source - mstore 0, 1 - mstore 32, 2 - %2 = mload 0 + mstore {a}, 1 + mstore {b}, 2 + %2 = mload {a} stop """ - post = """ + post = f""" _global: %1 = source - mstore 0, 1 + mstore {a}, 1 nop - %2 = mload 0 + %2 = mload {a} stop """ _check_pre_post(pre, post) -def test_basic_not_dead_store_with_return(): - pre = """ +@pytest.mark.parametrize("positions", [(0, 32), ("[0,32]", "[1,32]"), ("[2,32]", "[3,32]")]) +def test_basic_not_dead_store_with_return(positions): + a, b = positions + pre = f""" _global: %1 = source - mstore 0, 1 - mstore 32, 2 - return 0, 32 + mstore {a}, 1 + mstore {b}, 2 + return {a}, 32 """ - post = """ + post = f""" _global: %1 = source - mstore 0, 1 + mstore {a}, 1 nop - return 0, 32 + return {a}, 32 """ _check_pre_post(pre, post) -def test_never_read_store(): - pre = """ +@pytest.mark.parametrize("position", [0, 32, "[0,32]", "[1,32]"]) +def test_never_read_store(position): + pre = f""" _global: %val = 42 - mstore 0, %val ; Dead store - never read + mstore {position}, %val ; Dead store - never read stop """ post = """ @@ -171,34 +177,37 @@ def test_never_read_store(): _check_pre_post(pre, post) -def test_live_store(): - pre = """ +@pytest.mark.parametrize("position", [0, 32, "[0,32]", "[1,32]"]) +def test_live_store(position): + pre = f""" _global: %val = 42 - mstore 0, %val - %loaded = mload 0 ; Makes the store live + mstore {position}, %val + %loaded = mload {position} ; Makes the store live stop """ _check_pre_post(pre, pre) # Should not change -def test_dead_store_different_locations(): - pre = """ +@pytest.mark.parametrize("positions", [(0, 32), ("[0,32]", "[1,32]"), ("[2,32]", "[3,32]")]) +def test_dead_store_different_locations(positions): + a, b = positions + pre = f""" _global: %val1 = 42 %val2 = 24 - mstore 0, %val1 ; Dead store - never read - mstore 32, %val2 ; Live store - %loaded = mload 32 + mstore {a}, %val1 ; Dead store - never read + mstore {b}, %val2 ; Live store + %loaded = mload {b} stop """ - post = """ + post = f""" _global: %val1 = 42 %val2 = 24 nop - mstore 32, %val2 - %loaded = mload 32 + mstore {b}, %val2 + %loaded = mload {b} stop """ _check_pre_post(pre, post) diff --git a/tests/unit/compiler/venom/test_load_elimination.py b/tests/unit/compiler/venom/test_load_elimination.py index a6169b5222..72e261a3a3 100644 --- a/tests/unit/compiler/venom/test_load_elimination.py +++ 
b/tests/unit/compiler/venom/test_load_elimination.py
@@ -31,12 +31,16 @@ def _fill_symbolic(addrspace):
 RW_ADDRESS_SPACES = (MEMORY, STORAGE, TRANSIENT)
 
 
+@pytest.mark.parametrize("position", [11, "[2,32]"])
 @pytest.mark.parametrize("addrspace", ADDRESS_SPACES)
-def test_simple_load_elimination(addrspace):
+def test_simple_load_elimination(addrspace, position):
+    if addrspace != MEMORY and not isinstance(position, int):
+        return
+
     LOAD = addrspace.load_op
     pre = f"""
     main:
-        %ptr = 11
+        %ptr = {position}
         %1 = {LOAD} %ptr
         %2 = {LOAD} %ptr
@@ -44,7 +48,7 @@ def test_simple_load_elimination(addrspace):
     """
     post = f"""
     main:
-        %ptr = 11
+        %ptr = {position}
         %1 = {LOAD} %ptr
         %2 = %1
@@ -53,15 +57,19 @@ def test_simple_load_elimination(addrspace):
     _check_pre_post(pre, post)
 
 
+@pytest.mark.parametrize("position", [11, "[2,32]"])
 @pytest.mark.parametrize("addrspace", ADDRESS_SPACES)
-def test_equivalent_var_elimination(addrspace):
+def test_equivalent_var_elimination(addrspace, position):
     """
     Test that the lattice can "peer through" equivalent vars
     """
+    if addrspace != MEMORY and not isinstance(position, int):
+        return
+
     LOAD = addrspace.load_op
     pre = f"""
     main:
-        %1 = 11
+        %1 = {position}
         %2 = %1
 
         %3 = {LOAD} %1
@@ -71,7 +79,7 @@ def test_equivalent_var_elimination(addrspace):
     post = f"""
     main:
-        %1 = 11
+        %1 = {position}
         %2 = %1
 
         %3 = {LOAD} %1
@@ -99,18 +107,25 @@ def test_elimination_barrier():
     _check_no_change(pre)
 
 
+@pytest.mark.parametrize("position", [[55, 11], ["[1,32]", "[2,32]"]])
 @pytest.mark.parametrize("addrspace", RW_ADDRESS_SPACES)
-def test_store_load_elimination(addrspace):
+def test_store_load_elimination(addrspace, position: list):
     """
     Check that lattice stores the result of stores
     (even through equivalent variables)
     """
+    if addrspace != MEMORY and not all(isinstance(p, int) for p in position):
+        return
+
     LOAD = addrspace.load_op
     STORE = addrspace.store_op
 
+    val, ptr = position
+
     pre = f"""
     main:
-        %val = 55
-        %ptr1 = 11
+        %val = {val}
+        %ptr1 = {ptr}
         %ptr2 = %ptr1
 
         {STORE} %ptr1, %val
@@ -120,8 +135,8 @@ def test_store_load_elimination(addrspace):
     post = f"""
     main:
-        %val = 55
-        %ptr1 = 11
+        %val = {val}
+        %ptr1 = {ptr}
         %ptr2 = %ptr1
 
         {STORE} %ptr1, %val
@@ -150,16 +165,19 @@ def test_store_load_barrier():
     _check_no_change(pre)
 
 
-def test_store_load_overlap_barrier():
+@pytest.mark.parametrize("position", [(10, 20), (32, 63)])
+def test_store_load_overlap_barrier(position: tuple):
     """
     Check for barrier between store/load done
     by overlap of the mstore and mload
     """
-    pre = """
+    ptr_mload, ptr_mstore = position
+
+    pre = f"""
     main:
-        %ptr_mload = 10
-        %ptr_mstore = 20
+        %ptr_mload = {ptr_mload}
+        %ptr_mstore = {ptr_mstore}
         %tmp01 = mload %ptr_mload
 
         # barrier created with overlap
@@ -171,6 +189,37 @@ def test_store_load_overlap_barrier():
     _check_no_change(pre)
 
 
+def test_store_load_pair_memloc():
+    """
+    Check that two distinct abstract memory locations never overlap,
+    so the intervening mstore is no barrier to load elimination
+    """
+
+    pre = """
+    main:
+        %ptr_mload = [1,32]
+        %ptr_mstore = [2,32]
+        %tmp01 = mload %ptr_mload
+
+        # no barrier: distinct memory locations
+        mstore %ptr_mstore, 11
+        %tmp02 = mload %ptr_mload
+        return %tmp01, %tmp02
+    """
+    post = """
+    main:
+        %ptr_mload = [1,32]
+        %ptr_mstore = [2,32]
+        %tmp01 = mload %ptr_mload
+
+        # no barrier: distinct memory locations
+        mstore %ptr_mstore, 11
+        return %tmp01, %tmp01
+    """
+
+    _check_pre_post(pre, post)
+
+
 def test_store_store_overlap_barrier():
     """
     Check for barrier between store/load done
@@ -232,21 +281,27 @@ def test_store_load_no_overlap_different_store():
     _check_pre_post(pre, post)
+@pytest.mark.parametrize("position", [(10, 42), ("[2,32]", "[3,32]")])
 @pytest.mark.parametrize("addrspace", RW_ADDRESS_SPACES)
-def test_store_store_no_overlap(addrspace):
+def test_store_store_no_overlap(addrspace, position: tuple):
     """
     Test that if the mstores do not overlap
     it can still eliminate any possible repeated mstores
     """
+    if addrspace != MEMORY and not all(isinstance(p, int) for p in position):
+        return
+
     LOAD = addrspace.load_op
     STORE = addrspace.store_op
 
+    ptr_1, ptr_2 = position
+
     pre = f"""
     main:
         {_fill_symbolic(addrspace)}
 
-        %ptr_mstore01 = 10
-        %ptr_mstore02 = 42
+        %ptr_mstore01 = {ptr_1}
+        %ptr_mstore02 = {ptr_2}
         {STORE} %ptr_mstore01, 10
         {STORE} %ptr_mstore02, 11
@@ -262,8 +317,8 @@ def test_store_store_no_overlap(addrspace):
     main:
         {_fill_symbolic(addrspace)}
 
-        %ptr_mstore01 = 10
-        %ptr_mstore02 = 42
+        %ptr_mstore01 = {ptr_1}
+        %ptr_mstore02 = {ptr_2}
         {STORE} %ptr_mstore01, 10
         {STORE} %ptr_mstore02, 11
@@ -276,15 +331,16 @@
     _check_pre_post(pre, post)
 
 
-def test_store_store_unknown_ptr_barrier():
+@pytest.mark.parametrize("position", [10, "[2,32]"])
+def test_store_store_unknown_ptr_barrier(position):
     """
     Check for barrier between store/load done
     by overlap of the mstore and mload
     """
-    pre = """
+    pre = f"""
     main:
-        %ptr_mstore01 = 10
+        %ptr_mstore01 = {position}
         %ptr_mstore02 = source
 
         mstore %ptr_mstore01, 10
@@ -298,11 +354,12 @@
     _check_no_change(pre)
 
 
-def test_simple_load_elimination_inter():
-    pre = """
+@pytest.mark.parametrize("position", [5, "[2,32]"])
+def test_simple_load_elimination_inter(position):
+    pre = f"""
     main:
         %par = param
-        %1 = mload 5
+        %1 = mload {position}
         %cond = iszero %par
         jnz %cond, @then, @else
     then:
@@ -310,14 +367,14 @@
     else:
         jmp @join
     join:
-        %3 = mload 5
+        %3 = mload {position}
         sink %3
     """
 
-    post = """
+    post = f"""
     main:
         %par = param
-        %1 = mload 5
+        %1 = mload {position}
         %cond = iszero %par
         jnz %cond, @then, @else
     then:
@@ -332,33 +389,34 @@
     _check_pre_post(pre, post)
 
 
-def test_simple_load_elimination_inter_join():
-    pre = """
+@pytest.mark.parametrize("position", [5, "[2,32]"])
+def test_simple_load_elimination_inter_join(position):
+    pre = f"""
     main:
         %par = param
         %cond = iszero %par
         jnz %cond, @then, @else
     then:
-        %1 = mload 5
+        %1 = mload {position}
        jmp @join
     else:
-        %2 = mload 5
+        %2 = mload {position}
        jmp @join
     join:
-        %3 = mload 5
+        %3 = mload {position}
         sink %3
     """
 
-    post = """
+    post = f"""
     main:
         %par = param
         %cond = iszero %par
         jnz %cond, @then, @else
     then:
-        %1 = mload 5
+        %1 = mload {position}
         jmp @join
     else:
-        %2 = mload 5
+        %2 = mload {position}
         jmp @join
     join:
         %4 = phi @then, %1, @else, %2
@@ -369,51 +427,54 @@
     _check_pre_post(pre, post)
 
 
-def test_load_elimination_inter_distant_bb():
-    pre = """
+@pytest.mark.parametrize("position", [(5, 1000, 50), ("[2,32]", "[3,32]", "[4,32]")])
+def test_load_elimination_inter_distant_bb(position):
+    a, b, c = position
+
+    pre = f"""
     main:
         %par = param
         %cond = iszero %par
         jnz %cond, @then, @else
     then:
-        %1 = 
mload 5 + %1 = mload {a} jmp @join else: - %2 = mload 5 + %2 = mload {a} jmp @join join: %6 = phi @then, %1, @else, %2 - %3 = mload 1000 + %3 = mload {b} %cond_end = iszero %3 jnz %cond_end, @end_a, @end_b end_a: %4 = %6 sink %4 end_b: - %5 = mload 50 + %5 = mload {c} sink %5 """ diff --git a/tests/unit/compiler/venom/test_mem_alias.py b/tests/unit/compiler/venom/test_mem_alias.py index 94b9f09e29..7e5cd76c28 100644 --- a/tests/unit/compiler/venom/test_mem_alias.py +++ b/tests/unit/compiler/venom/test_mem_alias.py @@ -1,10 +1,10 @@ from vyper.venom.analysis import IRAnalysesCache from vyper.venom.analysis.mem_alias import MemoryAliasAnalysis from vyper.venom.basicblock import IRLabel -from vyper.venom.memory_location import MemoryLocation +from vyper.venom.memory_location import MemoryLocationSegment from vyper.venom.parser import parse_venom -FULL_MEMORY_ACCESS = MemoryLocation(offset=0, size=None) +FULL_MEMORY_ACCESS = MemoryLocationSegment(_offset=0, _size=None) def test_may_alias_full_memory_access(): @@ -20,7 +20,7 @@ def test_may_alias_full_memory_access(): alias = MemoryAliasAnalysis(ac, fn) alias.analyze() - loc1 = MemoryLocation(offset=0, size=32) + loc1 = MemoryLocationSegment(_offset=0, _size=32) assert alias.may_alias( FULL_MEMORY_ACCESS, loc1 ), "FULL_MEMORY_ACCESS should alias with regular location" @@ -29,26 +29,26 @@ def test_may_alias_full_memory_access(): ), "FULL_MEMORY_ACCESS should alias with regular location" assert not alias.may_alias( - FULL_MEMORY_ACCESS, MemoryLocation.EMPTY + FULL_MEMORY_ACCESS, MemoryLocationSegment.EMPTY ), "FULL_MEMORY_ACCESS should not alias with EMPTY_MEMORY_ACCESS" assert not alias.may_alias( - MemoryLocation.EMPTY, FULL_MEMORY_ACCESS + MemoryLocationSegment.EMPTY, FULL_MEMORY_ACCESS ), "FULL_MEMORY_ACCESS should not alias with EMPTY_MEMORY_ACCESS" assert alias.may_alias( FULL_MEMORY_ACCESS, FULL_MEMORY_ACCESS ), "FULL_MEMORY_ACCESS should alias with itself" - loc1 = MemoryLocation(offset=0, size=32) + loc1 = MemoryLocationSegment(_offset=0, _size=32) assert not alias.may_alias( - MemoryLocation.EMPTY, loc1 + MemoryLocationSegment.EMPTY, loc1 ), "EMPTY_MEMORY_ACCESS should not alias with regular location" assert not alias.may_alias( - loc1, MemoryLocation.EMPTY + loc1, MemoryLocationSegment.EMPTY ), "EMPTY_MEMORY_ACCESS should not alias with regular location" assert not alias.may_alias( - MemoryLocation.EMPTY, MemoryLocation.EMPTY + MemoryLocationSegment.EMPTY, MemoryLocationSegment.EMPTY ), "EMPTY_MEMORY_ACCESS should not alias with itself" @@ -65,8 +65,8 @@ def test_may_alias_volatile(): alias = MemoryAliasAnalysis(ac, fn) alias.analyze() - volatile_loc = MemoryLocation(offset=0, size=32, is_volatile=True) - regular_loc = MemoryLocation(offset=0, size=32) + volatile_loc = MemoryLocationSegment(_offset=0, _size=32, _is_volatile=True) + regular_loc = MemoryLocationSegment(_offset=0, _size=32) assert alias.may_alias( volatile_loc, regular_loc ), "Volatile location should alias with overlapping regular location" @@ -74,7 +74,7 @@ def test_may_alias_volatile(): regular_loc, volatile_loc ), "Regular location should alias with overlapping volatile location" - non_overlapping_loc = MemoryLocation(offset=32, size=32) + non_overlapping_loc = MemoryLocationSegment(_offset=32, _size=32) assert not alias.may_alias( volatile_loc, non_overlapping_loc ), "Volatile location should not alias with non-overlapping location" @@ -96,9 +96,9 @@ def test_mark_volatile(): alias = MemoryAliasAnalysis(ac, fn) alias.analyze() - loc1 = MemoryLocation(offset=0, 
size=32) - loc2 = MemoryLocation(offset=0, size=32) - loc3 = MemoryLocation(offset=32, size=32) + loc1 = MemoryLocationSegment(_offset=0, _size=32) + loc2 = MemoryLocationSegment(_offset=0, _size=32) + loc3 = MemoryLocationSegment(_offset=32, _size=32) alias._analyze_mem_location(loc1) alias._analyze_mem_location(loc2) @@ -141,9 +141,9 @@ def test_may_alias_with_alias_sets(): alias = MemoryAliasAnalysis(ac, fn) alias.analyze() - loc1 = MemoryLocation(offset=0, size=32) - loc2 = MemoryLocation(offset=0, size=32) - loc3 = MemoryLocation(offset=32, size=32) + loc1 = MemoryLocationSegment(_offset=0, _size=32) + loc2 = MemoryLocationSegment(_offset=0, _size=32) + loc3 = MemoryLocationSegment(_offset=32, _size=32) alias._analyze_mem_location(loc1) alias._analyze_mem_location(loc2) @@ -153,7 +153,7 @@ def test_may_alias_with_alias_sets(): assert not alias.may_alias(loc1, loc3), "Locations in different alias sets should not alias" # Test may_alias with new location not in alias sets - loc4 = MemoryLocation(offset=0, size=32) + loc4 = MemoryLocationSegment(_offset=0, _size=32) assert alias.may_alias(loc1, loc4), "New location should alias with existing location" assert loc4 in alias.alias_sets, "New location should be added to alias sets" @@ -172,7 +172,7 @@ def test_mark_volatile_edge_cases(): alias.analyze() # Test marking a location not in alias sets - loc1 = MemoryLocation(offset=0, size=32) + loc1 = MemoryLocationSegment(_offset=0, _size=32) volatile_loc = alias.mark_volatile(loc1) assert volatile_loc.is_volatile, "Marked location should be volatile" assert ( @@ -180,7 +180,7 @@ def test_mark_volatile_edge_cases(): ), "Volatile location should not be in alias sets if original wasn't" # Test marking a location with no aliases - loc2 = MemoryLocation(offset=0, size=32) + loc2 = MemoryLocationSegment(_offset=0, _size=32) alias._analyze_mem_location(loc2) volatile_loc2 = alias.mark_volatile(loc2) assert volatile_loc2 in alias.alias_sets, "Volatile location should be in alias sets" @@ -209,35 +209,35 @@ def test_may_alias_edge_cases(): alias.analyze() assert not alias.may_alias( - FULL_MEMORY_ACCESS, MemoryLocation.EMPTY + FULL_MEMORY_ACCESS, MemoryLocationSegment.EMPTY ), "FULL_MEMORY_ACCESS should not alias with EMPTY_MEMORY_ACCESS" assert not alias.may_alias( - MemoryLocation.EMPTY, FULL_MEMORY_ACCESS + MemoryLocationSegment.EMPTY, FULL_MEMORY_ACCESS ), "EMPTY_MEMORY_ACCESS should not alias with FULL_MEMORY_ACCESS" - loc1 = MemoryLocation(offset=0, size=32) + loc1 = MemoryLocationSegment(_offset=0, _size=32) assert not alias.may_alias( - MemoryLocation.EMPTY, loc1 + MemoryLocationSegment.EMPTY, loc1 ), "EMPTY_MEMORY_ACCESS should not alias with regular location" assert not alias.may_alias( - loc1, MemoryLocation.EMPTY + loc1, MemoryLocationSegment.EMPTY ), "Regular location should not alias with EMPTY_MEMORY_ACCESS" - volatile_loc = MemoryLocation(offset=0, size=32, is_volatile=True) - non_overlapping_loc = MemoryLocation(offset=32, size=32) + volatile_loc = MemoryLocationSegment(_offset=0, _size=32, _is_volatile=True) + non_overlapping_loc = MemoryLocationSegment(_offset=32, _size=32) assert not alias.may_alias( volatile_loc, non_overlapping_loc ), "Volatile location should not alias with non-overlapping location" - loc2 = MemoryLocation(offset=0, size=32) - loc3 = MemoryLocation(offset=32, size=32) + loc2 = MemoryLocationSegment(_offset=0, _size=32) + loc3 = MemoryLocationSegment(_offset=32, _size=32) assert alias.may_alias(loc2, loc3) == alias.may_alias( loc2, loc3 ), "may_alias should use 
may_alias for locations not in alias sets" - loc4 = MemoryLocation(offset=0, size=32) - loc5 = MemoryLocation(offset=0, size=32) - loc6 = MemoryLocation(offset=32, size=32) + loc4 = MemoryLocationSegment(_offset=0, _size=32) + loc5 = MemoryLocationSegment(_offset=0, _size=32) + loc6 = MemoryLocationSegment(_offset=32, _size=32) alias._analyze_mem_location(loc4) alias._analyze_mem_location(loc5) alias._analyze_mem_location(loc6) @@ -263,31 +263,31 @@ def test_may_alias_edge_cases2(): alias = MemoryAliasAnalysis(ac, fn) alias.analyze() - loc1 = MemoryLocation(offset=0, size=32) + loc1 = MemoryLocationSegment(_offset=0, _size=32) assert alias.may_alias( FULL_MEMORY_ACCESS, loc1 ), "FULL_MEMORY_ACCESS should alias with regular location" assert not alias.may_alias( - MemoryLocation.EMPTY, loc1 + MemoryLocationSegment.EMPTY, loc1 ), "EMPTY_MEMORY_ACCESS should not alias with regular location" - volatile_loc = MemoryLocation(offset=0, size=32, is_volatile=True) - overlapping_loc = MemoryLocation(offset=16, size=32) + volatile_loc = MemoryLocationSegment(_offset=0, _size=32, _is_volatile=True) + overlapping_loc = MemoryLocationSegment(_offset=16, _size=32) assert alias.may_alias( volatile_loc, overlapping_loc ), "Volatile location should alias with overlapping location" - loc2 = MemoryLocation(offset=0, size=64) - loc3 = MemoryLocation(offset=32, size=64) + loc2 = MemoryLocationSegment(_offset=0, _size=64) + loc3 = MemoryLocationSegment(_offset=32, _size=64) result = alias.may_alias(loc2, loc3) assert result == alias.may_alias( loc2, loc3 ), "may_alias should use may_alias for locations not in alias sets" - loc4 = MemoryLocation(offset=0, size=32) - loc5 = MemoryLocation(offset=0, size=32) - loc6 = MemoryLocation(offset=0, size=32) + loc4 = MemoryLocationSegment(_offset=0, _size=32) + loc5 = MemoryLocationSegment(_offset=0, _size=32) + loc6 = MemoryLocationSegment(_offset=0, _size=32) alias._analyze_mem_location(loc4) alias._analyze_mem_location(loc5) alias._analyze_mem_location(loc6) diff --git a/tests/unit/compiler/venom/test_mem_ssa.py b/tests/unit/compiler/venom/test_mem_ssa.py index 148d47bf73..0da5229202 100644 --- a/tests/unit/compiler/venom/test_mem_ssa.py +++ b/tests/unit/compiler/venom/test_mem_ssa.py @@ -13,7 +13,7 @@ ) from vyper.venom.basicblock import IRBasicBlock, IRLabel from vyper.venom.effects import Effects -from vyper.venom.memory_location import get_read_location, get_write_location +from vyper.venom.memory_location import MemoryLocationSegment, get_read_location, get_write_location @pytest.fixture @@ -327,16 +327,16 @@ def test_may_alias(dummy_mem_ssa): mem_ssa, _, _ = dummy_mem_ssa # Test non-overlapping memory locations - loc1 = MemoryLocation(offset=0, size=32) - loc2 = MemoryLocation(offset=32, size=32) + loc1 = MemoryLocationSegment(_offset=0, _size=32) + loc2 = MemoryLocationSegment(_offset=32, _size=32) assert not mem_ssa.memalias.may_alias(loc1, loc2), "Non-overlapping locations should not alias" # Test overlapping memory locations - loc3 = MemoryLocation(offset=0, size=16) - loc4 = MemoryLocation(offset=8, size=8) + loc3 = MemoryLocationSegment(_offset=0, _size=16) + loc4 = MemoryLocationSegment(_offset=8, _size=8) assert mem_ssa.memalias.may_alias(loc3, loc4), "Overlapping locations should alias" - full_loc = MemoryLocation(offset=0, size=None) + full_loc = MemoryLocationSegment(_offset=0, _size=None) assert mem_ssa.memalias.may_alias(full_loc, loc1), "should alias with any non-empty location" assert not mem_ssa.memalias.may_alias( full_loc, MemoryLocation.EMPTY @@ 
-352,7 +352,7 @@ def test_may_alias(dummy_mem_ssa): ), "EMPTY_MEMORY_ACCESS should not alias" # Test zero/negative size locations - zero_size_loc = MemoryLocation(offset=0, size=0) + zero_size_loc = MemoryLocationSegment(_offset=0, _size=0) assert not mem_ssa.memalias.may_alias( zero_size_loc, loc1 ), "Zero size location should not alias" @@ -361,19 +361,19 @@ def test_may_alias(dummy_mem_ssa): ), "Zero size locations should not alias with each other" # Test partial overlap - loc5 = MemoryLocation(offset=0, size=64) - loc6 = MemoryLocation(offset=32, size=32) + loc5 = MemoryLocationSegment(_offset=0, _size=64) + loc6 = MemoryLocationSegment(_offset=32, _size=32) assert mem_ssa.memalias.may_alias(loc5, loc6), "Partially overlapping locations should alias" assert mem_ssa.memalias.may_alias(loc6, loc5), "Partially overlapping locations should alias" # Test exact same location - loc7 = MemoryLocation(offset=0, size=64) - loc8 = MemoryLocation(offset=0, size=64) + loc7 = MemoryLocationSegment(_offset=0, _size=64) + loc8 = MemoryLocationSegment(_offset=0, _size=64) assert mem_ssa.memalias.may_alias(loc7, loc8), "Identical locations should alias" # Test adjacent but non-overlapping locations - loc9 = MemoryLocation(offset=0, size=64) - loc10 = MemoryLocation(offset=64, size=64) + loc9 = MemoryLocationSegment(_offset=0, _size=64) + loc10 = MemoryLocationSegment(_offset=64, _size=64) assert not mem_ssa.memalias.may_alias( loc9, loc10 ), "Adjacent but non-overlapping locations should not alias" @@ -564,8 +564,8 @@ def test_analyze_instruction_with_no_memory_ops(): assignment_inst = bb.instructions[0] # %1 = 42 # Verify that the instruction doesn't have memory operations - assert get_read_location(assignment_inst, MEMORY) is MemoryLocation.EMPTY - assert get_write_location(assignment_inst, MEMORY) is MemoryLocation.EMPTY + assert get_read_location(assignment_inst, MEMORY, {}) is MemoryLocation.EMPTY + assert get_write_location(assignment_inst, MEMORY, {}) is MemoryLocation.EMPTY assert mem_ssa.memalias.alias_sets is not None @@ -826,7 +826,7 @@ def test_get_reaching_def_with_phi(): # Create a new memory definition with the same location as the phi new_def = MemoryDef(mem_ssa.next_id, merge_block.instructions[0], MEMORY) mem_ssa.next_id += 1 - new_def.loc = MemoryLocation(offset=0, size=32) # Same location as the phi + new_def.loc = MemoryLocationSegment(_offset=0, _size=32) # Same location as the phi result = mem_ssa._get_reaching_def(new_def) assert result == phi @@ -846,7 +846,7 @@ def test_get_reaching_def_with_no_phi(): new_def = MemoryDef(mem_ssa.next_id, entry_block.instructions[0], MEMORY) mem_ssa.next_id += 1 - new_def.loc = MemoryLocation(offset=0, size=32) + new_def.loc = MemoryLocationSegment(_offset=0, _size=32) result = mem_ssa._get_reaching_def(new_def) assert result == mem_ssa.live_on_entry diff --git a/tests/unit/compiler/venom/test_memory_location.py b/tests/unit/compiler/venom/test_memory_location.py index 3c32161096..4623b46bad 100644 --- a/tests/unit/compiler/venom/test_memory_location.py +++ b/tests/unit/compiler/venom/test_memory_location.py @@ -1,14 +1,14 @@ -from vyper.venom.memory_location import MemoryLocation +from vyper.venom.memory_location import MemoryLocation, MemoryLocationSegment def test_completely_overlaps(): # Create memory locations with different offsets and sizes - loc1 = MemoryLocation(offset=0, size=32) - loc2 = MemoryLocation(offset=0, size=32) # Same as loc1 - loc3 = MemoryLocation(offset=0, size=64) # Larger than loc1 - loc4 = MemoryLocation(offset=16, 
size=16) # Inside loc1 - loc5 = MemoryLocation(offset=16, size=32) # Partially overlaps loc1 - loc6 = MemoryLocation(offset=32, size=32) # Adjacent to loc1 + loc1 = MemoryLocationSegment(_offset=0, _size=32) + loc2 = MemoryLocationSegment(_offset=0, _size=32) # Same as loc1 + loc3 = MemoryLocationSegment(_offset=0, _size=64) # Larger than loc1 + loc4 = MemoryLocationSegment(_offset=16, _size=16) # Inside loc1 + loc5 = MemoryLocationSegment(_offset=16, _size=32) # Partially overlaps loc1 + loc6 = MemoryLocationSegment(_offset=32, _size=32) # Adjacent to loc1 assert loc1.completely_contains(loc1) assert loc1.completely_contains(loc2) @@ -21,7 +21,7 @@ def test_completely_overlaps(): assert not loc1.completely_contains(loc6) # Test with EMPTY and FULL memory access - full_loc = MemoryLocation(offset=0, size=None) + full_loc = MemoryLocationSegment(_offset=0, _size=None) assert not MemoryLocation.EMPTY.completely_contains(loc1) assert loc1.completely_contains(MemoryLocation.EMPTY) assert not full_loc.completely_contains(loc1) diff --git a/vyper/venom/__init__.py b/vyper/venom/__init__.py index 1b166bda77..556cff716f 100644 --- a/vyper/venom/__init__.py +++ b/vyper/venom/__init__.py @@ -6,14 +6,14 @@ from vyper.codegen.ir_node import IRnode from vyper.compiler.settings import OptimizationLevel, Settings from vyper.evm.address_space import MEMORY, STORAGE, TRANSIENT -from vyper.exceptions import CompilerPanic from vyper.ir.compile_ir import AssemblyInstruction -from vyper.venom.analysis import MemSSA +from vyper.venom.analysis import FCGAnalysis from vyper.venom.analysis.analysis import IRAnalysesCache -from vyper.venom.basicblock import IRLabel, IRLiteral +from vyper.venom.basicblock import IRAbstractMemLoc, IRLabel, IRLiteral from vyper.venom.context import IRContext from vyper.venom.function import IRFunction from vyper.venom.ir_node_to_venom import ir_node_to_venom +from vyper.venom.memory_location import fix_mem_loc from vyper.venom.passes import ( CSE, SCCP, @@ -21,7 +21,9 @@ AssignElimination, BranchOptimizationPass, CFGNormalization, + ConcretizeMemLocPass, DFTPass, + FixCalloca, FloatAllocas, FunctionInlinerPass, LoadElimination, @@ -92,6 +94,20 @@ def _run_passes(fn: IRFunction, optimize: OptimizationLevel, ac: IRAnalysesCache DeadStoreElimination(ac, fn).run_pass(addr_space=TRANSIENT) LowerDloadPass(ac, fn).run_pass() + PhiEliminationPass(ac, fn).run_pass() + AssignElimination(ac, fn).run_pass() + ConcretizeMemLocPass(ac, fn).run_pass() + SCCP(ac, fn).run_pass() + AssignElimination(ac, fn).run_pass() + DeadStoreElimination(ac, fn).run_pass(addr_space=MEMORY) + LoadElimination(ac, fn).run_pass() + PhiEliminationPass(ac, fn).run_pass() + AssignElimination(ac, fn).run_pass() + SCCP(ac, fn).run_pass() + + SimplifyCFGPass(ac, fn).run_pass() + MemMergePass(ac, fn).run_pass() + RemoveUnusedVariablesPass(ac, fn).run_pass() BranchOptimizationPass(ac, fn).run_pass() AlgebraicOptimizationPass(ac, fn).run_pass() @@ -116,6 +132,7 @@ def _run_passes(fn: IRFunction, optimize: OptimizationLevel, ac: IRAnalysesCache def _run_global_passes(ctx: IRContext, optimize: OptimizationLevel, ir_analyses: dict) -> None: + FixCalloca(ir_analyses, ctx).run_pass() FunctionInlinerPass(ir_analyses, ctx, optimize).run_pass() @@ -130,8 +147,35 @@ def run_passes_on(ctx: IRContext, optimize: OptimizationLevel) -> None: for fn in ctx.functions.values(): ir_analyses[fn] = IRAnalysesCache(fn) - for fn in ctx.functions.values(): - _run_passes(fn, optimize, ir_analyses[fn]) + assert ctx.entry_function is not None + fcg = 
ir_analyses[ctx.entry_function].force_analysis(FCGAnalysis)
+
+    _run_fn_passes(ctx, fcg, ctx.entry_function, optimize, ir_analyses)
+
+
+def _run_fn_passes(
+    ctx: IRContext, fcg: FCGAnalysis, fn: IRFunction, optimize: OptimizationLevel, ir_analyses: dict
+):
+    visited: set[IRFunction] = set()
+    assert ctx.entry_function is not None
+    _run_fn_passes_r(ctx, fcg, fn, optimize, ir_analyses, visited)
+
+
+def _run_fn_passes_r(
+    ctx: IRContext,
+    fcg: FCGAnalysis,
+    fn: IRFunction,
+    optimize: OptimizationLevel,
+    ir_analyses: dict,
+    visited: set,
+):
+    if fn in visited:
+        return
+    visited.add(fn)
+    for next_fn in fcg.get_callees(fn):
+        _run_fn_passes_r(ctx, fcg, next_fn, optimize, ir_analyses, visited)
+
+    _run_passes(fn, optimize, ir_analyses[fn])
 
 
 def generate_venom(
@@ -145,6 +189,14 @@
     starting_symbols = {k: IRLiteral(v) for k, v in constants.items()}
     ctx = ir_node_to_venom(ir, starting_symbols)
 
+    # these mem locations are accessed by the sha3_64 instruction at
+    # concrete offsets, so they must be allocated up front
+    ctx.mem_allocator.allocate(IRAbstractMemLoc.FREE_VAR1)
+    ctx.mem_allocator.allocate(IRAbstractMemLoc.FREE_VAR2)
+
+    for fn in ctx.functions.values():
+        fix_mem_loc(fn)
+
     data_sections = data_sections or {}
     for section_name, data in data_sections.items():
         ctx.append_data_section(IRLabel(section_name))
diff --git a/vyper/venom/analysis/mem_alias.py b/vyper/venom/analysis/mem_alias.py
index 41a3a03f75..b83b84a971 100644
--- a/vyper/venom/analysis/mem_alias.py
+++ b/vyper/venom/analysis/mem_alias.py
@@ -1,10 +1,9 @@
-import dataclasses as dc
 from typing import Optional
 
 from vyper.evm.address_space import MEMORY, STORAGE, TRANSIENT, AddrSpace
 from vyper.utils import OrderedSet
 from vyper.venom.analysis import CFGAnalysis, DFGAnalysis, IRAnalysis
-from vyper.venom.basicblock import IRInstruction
+from vyper.venom.basicblock import IRAbstractMemLoc, IRInstruction, IRVariable
 from vyper.venom.memory_location import MemoryLocation, get_read_location, get_write_location
 
 
@@ -23,21 +22,47 @@ def analyze(self):
         # Map from memory locations to sets of potentially aliasing locations
         self.alias_sets: dict[MemoryLocation, OrderedSet[MemoryLocation]] = {}
 
+        self.var_base_pointers: dict[IRVariable, IRAbstractMemLoc] = {}
+
+        for bb in self.function.get_basic_blocks():
+            for inst in bb.instructions:
+                if inst.opcode != "gep":
+                    continue
+                place = self._follow_gep(inst)
+                assert inst.output is not None
+                self.var_base_pointers[inst.output] = place
 
         # Analyze all memory operations
         for bb in self.function.get_basic_blocks():
             for inst in bb.instructions:
                 self._analyze_instruction(inst)
 
+    def _follow_gep(self, inst: IRInstruction):
+        assert inst.opcode == "gep"
+        place = inst.operands[0]
+        if isinstance(place, IRVariable):
+            next_inst = self.dfg.get_producing_instruction(place)
+            assert next_inst is not None
+            place = self._follow_gep(next_inst)
+
+        assert isinstance(place, IRAbstractMemLoc)
+        return place
+
+    def _get_read_location(self, inst: IRInstruction, addr_space: AddrSpace) -> MemoryLocation:
+        return get_read_location(inst, addr_space, self.var_base_pointers)
+
+    def _get_write_location(self, inst: IRInstruction, addr_space: AddrSpace) -> MemoryLocation:
+        return get_write_location(inst, addr_space, self.var_base_pointers)
+
     def _analyze_instruction(self, inst: IRInstruction):
         """Analyze a memory instruction to determine aliasing"""
         loc: Optional[MemoryLocation] = None
 
-        loc = get_read_location(inst, self.addr_space)
+        loc = get_read_location(inst, self.addr_space, 
self.var_base_pointers) if loc is not None: self._analyze_mem_location(loc) - loc = get_write_location(inst, self.addr_space) + loc = get_write_location(inst, self.addr_space, self.var_base_pointers) if loc is not None: self._analyze_mem_location(loc) @@ -72,7 +97,7 @@ def may_alias(self, loc1: MemoryLocation, loc2: MemoryLocation) -> bool: return result def mark_volatile(self, loc: MemoryLocation) -> MemoryLocation: - volatile_loc = dc.replace(loc, is_volatile=True) + volatile_loc = loc.create_volatile() if loc in self.alias_sets: self.alias_sets[volatile_loc] = OrderedSet([volatile_loc]) diff --git a/vyper/venom/analysis/mem_ssa.py b/vyper/venom/analysis/mem_ssa.py index c1db8789ba..babe39c677 100644 --- a/vyper/venom/analysis/mem_ssa.py +++ b/vyper/venom/analysis/mem_ssa.py @@ -1,5 +1,4 @@ import contextlib -import dataclasses as dc from typing import Iterable, Optional from vyper.evm.address_space import MEMORY, STORAGE, TRANSIENT, AddrSpace @@ -12,7 +11,7 @@ TransientAliasAnalysis, ) from vyper.venom.basicblock import IRBasicBlock, IRInstruction, ir_printer -from vyper.venom.memory_location import MemoryLocation, get_read_location, get_write_location +from vyper.venom.memory_location import MemoryLocation class MemoryAccess: @@ -72,10 +71,10 @@ class LiveOnEntry(MemoryAccess): class MemoryDef(MemoryAccess): """Represents a definition of memory state""" - def __init__(self, id: int, store_inst: IRInstruction, addr_space: AddrSpace): + def __init__(self, id: int, store_inst: IRInstruction, loc: MemoryLocation): super().__init__(id) self.store_inst = store_inst - self.loc = get_write_location(store_inst, addr_space) + self.loc = loc @property def inst(self): @@ -85,10 +84,10 @@ def inst(self): class MemoryUse(MemoryAccess): """Represents a use of memory state""" - def __init__(self, id: int, load_inst: IRInstruction, addr_space: AddrSpace): + def __init__(self, id: int, load_inst: IRInstruction, loc: MemoryLocation): super().__init__(id) self.load_inst = load_inst - self.loc = get_read_location(load_inst, addr_space) + self.loc = loc @property def inst(self): @@ -154,6 +153,7 @@ def analyze(self): # Clean up unnecessary phi nodes self._remove_redundant_phis() + self.analyses_cache.invalidate_analysis(self.mem_alias_type) def mark_location_volatile(self, loc: MemoryLocation) -> MemoryLocation: volatile_loc = self.memalias.mark_volatile(loc) @@ -161,7 +161,7 @@ def mark_location_volatile(self, loc: MemoryLocation) -> MemoryLocation: for bb in self.memory_defs: for mem_def in self.memory_defs[bb]: if self.memalias.may_alias(mem_def.loc, loc): - mem_def.loc = dc.replace(mem_def.loc, is_volatile=True) + mem_def.loc = mem_def.loc.create_volatile() return volatile_loc @@ -194,15 +194,19 @@ def _process_block_definitions(self, block: IRBasicBlock): """Process memory definitions and uses in a basic block""" for inst in block.instructions: # Check for memory reads - if get_read_location(inst, self.addr_space) != MemoryLocation.EMPTY: - mem_use = MemoryUse(self.next_id, inst, self.addr_space) + if ( + loc := self.memalias._get_read_location(inst, self.addr_space) + ) != MemoryLocation.EMPTY: + mem_use = MemoryUse(self.next_id, inst, loc) self.next_id += 1 self.memory_uses.setdefault(block, []).append(mem_use) self.inst_to_use[inst] = mem_use # Check for memory writes - if get_write_location(inst, self.addr_space) != MemoryLocation.EMPTY: - mem_def = MemoryDef(self.next_id, inst, self.addr_space) + if ( + loc := self.memalias._get_write_location(inst, self.addr_space) + ) != MemoryLocation.EMPTY: + 
mem_def = MemoryDef(self.next_id, inst, loc) self.next_id += 1 self.memory_defs.setdefault(block, []).append(mem_def) self.inst_to_def[inst] = mem_def diff --git a/vyper/venom/basicblock.py b/vyper/venom/basicblock.py index 1b72094dd7..ec74e76afc 100644 --- a/vyper/venom/basicblock.py +++ b/vyper/venom/basicblock.py @@ -3,7 +3,7 @@ import json import re from contextvars import ContextVar -from typing import TYPE_CHECKING, Any, Iterator, Optional, Union +from typing import TYPE_CHECKING, Any, ClassVar, Iterator, Optional, Union import vyper.venom.effects as effects from vyper.codegen.ir_node import IRnode @@ -34,6 +34,7 @@ "extcodecopy", "returndatacopy", "codecopy", + "codecopyruntime", "dloadbytes", "return", "ret", @@ -62,6 +63,7 @@ "mcopy", "returndatacopy", "codecopy", + "codecopyruntime", "extcodecopy", "return", "ret", @@ -173,6 +175,48 @@ def __repr__(self) -> str: return f"0x{self.value:x}" +class IRAbstractMemLoc(IROperand): + _id: int + size: int + offset: int + + _curr_id: ClassVar[int] + FREE_VAR1: ClassVar["IRAbstractMemLoc"] + FREE_VAR2: ClassVar["IRAbstractMemLoc"] + + def __init__(self, size: int, offset: int = 0, force_id=None): + if force_id is None: + self._id = IRAbstractMemLoc._curr_id + IRAbstractMemLoc._curr_id += 1 + else: + self._id = force_id + super().__init__(self._id) + self.size = size + self.offset = offset + + def __hash__(self) -> int: + return self._id ^ self.offset + + def __eq__(self, other) -> bool: + if type(self) is not type(other): + return False + return self._id == other._id and self.offset == other.offset + + def __repr__(self) -> str: + return f"[{self._id},{self.size} + {self.offset}]" + + def no_offset(self) -> IRAbstractMemLoc: + return IRAbstractMemLoc(self.size, force_id=self._id) + + def with_offset(self, offset: int) -> IRAbstractMemLoc: + return IRAbstractMemLoc(self.size, offset=offset, force_id=self._id) + + +IRAbstractMemLoc._curr_id = 0 +IRAbstractMemLoc.FREE_VAR1 = IRAbstractMemLoc(32) +IRAbstractMemLoc.FREE_VAR2 = IRAbstractMemLoc(32) + + class IRVariable(IROperand): """ IRVariable represents a variable in IR. A variable is a string that starts with a %. 
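A quick illustration (not part of the diff) of the identity rules encoded by __eq__ and __hash__ above: identity keys on (_id, offset) only, size does not participate, so with_offset()/no_offset() produce views of the same underlying allocation while each plain construction mints a fresh id from the class counter:

    from vyper.venom.basicblock import IRAbstractMemLoc

    buf = IRAbstractMemLoc(256)          # fresh id from the class counter
    view = buf.with_offset(32)           # same allocation, different offset
    assert view._id == buf._id
    assert view != buf                   # offset participates in equality
    assert view.no_offset() == buf       # dropping the offset restores identity
    assert IRAbstractMemLoc(256) != buf  # new construction, new id
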
diff --git a/vyper/venom/context.py b/vyper/venom/context.py index f50dc1220f..a989d0e93e 100644 --- a/vyper/venom/context.py +++ b/vyper/venom/context.py @@ -4,6 +4,7 @@ from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRVariable from vyper.venom.function import IRFunction +from vyper.venom.memory_allocator import MemoryAllocator @dataclass @@ -37,6 +38,7 @@ class IRContext: data_segment: list[DataSection] last_label: int last_variable: int + mem_allocator: MemoryAllocator def __init__(self) -> None: self.functions = {} @@ -46,6 +48,7 @@ def __init__(self) -> None: self.last_label = 0 self.last_variable = 0 + self.mem_allocator = MemoryAllocator() def get_basic_blocks(self) -> Iterator[IRBasicBlock]: for fn in self.functions.values(): diff --git a/vyper/venom/effects.py b/vyper/venom/effects.py index e03d370269..4039ecae51 100644 --- a/vyper/venom/effects.py +++ b/vyper/venom/effects.py @@ -45,6 +45,7 @@ class Effects(Flag): "returndatacopy": MEMORY, "calldatacopy": MEMORY, "codecopy": MEMORY, + "codecopyruntime": MEMORY, "extcodecopy": MEMORY, "mcopy": MEMORY, } diff --git a/vyper/venom/function.py b/vyper/venom/function.py index c57229eabc..62a452d976 100644 --- a/vyper/venom/function.py +++ b/vyper/venom/function.py @@ -6,8 +6,7 @@ from typing import TYPE_CHECKING, Iterator, Optional from vyper.codegen.ir_node import IRnode -from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRVariable -from vyper.venom.memory_location import MemoryLocation +from vyper.venom.basicblock import IRAbstractMemLoc, IRBasicBlock, IRLabel, IRVariable if TYPE_CHECKING: from vyper.venom.context import IRContext @@ -33,9 +32,9 @@ class IRFunction: name: IRLabel # symbol name ctx: IRContext args: list + allocated_args: dict[int, IRAbstractMemLoc] last_variable: int _basic_block_dict: dict[str, IRBasicBlock] - _volatile_memory: list[MemoryLocation] # Used during code generation _ast_source_stack: list[IRnode] @@ -45,8 +44,8 @@ def __init__(self, name: IRLabel, ctx: IRContext = None): self.ctx = ctx # type: ignore self.name = name self.args = [] + self.allocated_args = dict() self._basic_block_dict = {} - self._volatile_memory = [] self.last_variable = 0 @@ -165,10 +164,6 @@ def copy(self): new_bb = bb.copy() new.append_basic_block(new_bb) - # Copy volatile memory locations - for mem in self._volatile_memory: - new.add_volatile_memory(mem.offset, mem.size) - return new def as_graph(self, only_subgraph=False) -> str: @@ -216,18 +211,3 @@ def __repr__(self) -> str: ret = ret.strip() + "\n}" ret += f" ; close function {self.name}" return ret - - def add_volatile_memory(self, offset: int, size: int) -> MemoryLocation: - """ - Add a volatile memory location with the given offset and size. - Returns the created MemoryLocation object. - """ - volatile_mem = MemoryLocation(offset=offset, size=size) - self._volatile_memory.append(volatile_mem) - return volatile_mem - - def get_all_volatile_memory(self) -> list[MemoryLocation]: - """ - Return all volatile memory locations. - """ - return self._volatile_memory diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py index 3e3842540b..97410faf26 100644 --- a/vyper/venom/ir_node_to_venom.py +++ b/vyper/venom/ir_node_to_venom.py @@ -10,6 +10,7 @@ from vyper.evm.opcodes import get_opcodes from vyper.ir.compile_ir import _runtime_code_offsets from vyper.venom.basicblock import ( + IRAbstractMemLoc, IRBasicBlock, IRInstruction, IRLabel, @@ -305,7 +306,9 @@ def _handle_internal_func( # buffer size of 32 bytes. 
# TODO: we don't need to use scratch space once the legacy optimizer # is disabled. - buf = bb.append_instruction("alloca", 0, 32, get_scratch_alloca_id()) + buf = bb.append_instruction( + "alloca", IRAbstractMemLoc.FREE_VAR1, get_scratch_alloca_id() + ) else: buf = bb.append_instruction("param") bb.instructions[-1].annotation = "return_buffer" @@ -422,11 +425,13 @@ def _convert_ir_bb(fn, ir, symbols): bb = fn.get_basic_block() + mem_start_var = bb.append_instruction("mem_deploy_start", mem_deploy_start) + bb.append_instruction( - "codecopy", runtime_codesize, IRLabel("runtime_begin"), mem_deploy_start + "codecopyruntime", runtime_codesize, IRLabel("runtime_begin"), mem_start_var ) amount_to_return = bb.append_instruction("add", runtime_codesize, immutables_len) - bb.append_instruction("return", amount_to_return, mem_deploy_start) + bb.append_instruction("return", amount_to_return, mem_start_var) return None elif ir.value == "seq": if len(ir.args) == 0: @@ -690,17 +695,19 @@ def emit_body_blocks(): if ir.value.startswith("$alloca"): alloca = ir.passthrough_metadata["alloca"] if alloca._id not in _alloca_table: - ptr = fn.get_basic_block().append_instruction( - "alloca", alloca.offset, alloca.size, alloca._id - ) + mem_loc_op = IRAbstractMemLoc(alloca.size) + ptr = fn.get_basic_block().append_instruction("alloca", mem_loc_op, alloca._id) _alloca_table[alloca._id] = ptr return _alloca_table[alloca._id] elif ir.value.startswith("$palloca"): + assert isinstance(fn, IRFunction) alloca = ir.passthrough_metadata["alloca"] if alloca._id not in _alloca_table: + mem_loc_op = IRAbstractMemLoc(alloca.size) + fn.allocated_args[alloca._id] = mem_loc_op bb = fn.get_basic_block() - ptr = bb.append_instruction("palloca", alloca.offset, alloca.size, alloca._id) + ptr = bb.append_instruction("palloca", mem_loc_op, alloca._id) bb.instructions[-1].annotation = f"{alloca.name} (memory)" if ENABLE_NEW_CALL_CONV and _pass_via_stack(_current_func_t)[alloca.name]: param = fn.get_param_by_id(alloca._id) @@ -717,11 +724,13 @@ def emit_body_blocks(): callsite_func = ir.passthrough_metadata["callsite_func"] if ENABLE_NEW_CALL_CONV and _pass_via_stack(callsite_func)[alloca.name]: - ptr = bb.append_instruction("alloca", alloca.offset, alloca.size, alloca._id) + ptr = bb.append_instruction("alloca", IRAbstractMemLoc(alloca.size), alloca._id) else: # if we use alloca, mstores might get removed. convert # to calloca until memory analysis is more sound. 
- ptr = bb.append_instruction("calloca", alloca.offset, alloca.size, alloca._id) + ptr = bb.append_instruction( + "calloca", alloca.size, alloca._id, IRLabel(alloca._callsite) + ) _alloca_table[alloca._id] = ptr ret = _alloca_table[alloca._id] diff --git a/vyper/venom/memory_allocator.py b/vyper/venom/memory_allocator.py new file mode 100644 index 0000000000..27a7adf7a5 --- /dev/null +++ b/vyper/venom/memory_allocator.py @@ -0,0 +1,27 @@ +from typing import Any + +from vyper.utils import OrderedSet +from vyper.venom.basicblock import IRAbstractMemLoc + + +class MemoryAllocator: + allocated: dict[int, tuple[int, int]] + curr: int + mems_used: dict[Any, OrderedSet[IRAbstractMemLoc]] + + def __init__(self): + self.curr = 0 + self.allocated = dict() + self.mems_used = dict() + + def allocate(self, mem_loc: IRAbstractMemLoc) -> int: + ptr = self.curr + self.curr += mem_loc.size + self.allocated[mem_loc._id] = (ptr, mem_loc.size) + return ptr + + def start_fn_allocation(self): + self.curr = 64 + + def end_fn_allocation(self, mems: list[IRAbstractMemLoc], fn): + self.mems_used[fn] = OrderedSet(mems) diff --git a/vyper/venom/memory_location.py b/vyper/venom/memory_location.py index ec2a2f9da8..e7639bcf45 100644 --- a/vyper/venom/memory_location.py +++ b/vyper/venom/memory_location.py @@ -1,26 +1,169 @@ from __future__ import annotations +import dataclasses as dc from dataclasses import dataclass from typing import ClassVar from vyper.evm.address_space import MEMORY, STORAGE, TRANSIENT, AddrSpace from vyper.exceptions import CompilerPanic -from vyper.venom.basicblock import IRLiteral, IROperand, IRVariable +from vyper.utils import MemoryPositions +from vyper.venom.basicblock import IRAbstractMemLoc, IRInstruction, IRLiteral, IROperand, IRVariable +from vyper.venom.function import IRFunction -@dataclass(frozen=True) class MemoryLocation: + # Initialize after class definition + EMPTY: ClassVar[MemoryLocation] + UNDEFINED: ClassVar[MemoryLocation] + + @classmethod + def from_operands( + cls, offset: IROperand | int, size: IROperand | int, var_base_pointers: dict + ) -> MemoryLocation: + if isinstance(size, IRLiteral): + _size = size.value + elif isinstance(size, IRVariable): + _size = None + elif isinstance(size, int): + _size = size + else: # pragma: nocover + raise CompilerPanic(f"invalid size: {size} ({type(size)})") + + if isinstance(offset, IRLiteral): + return MemoryLocationSegment(offset.value, _size) + elif isinstance(offset, IRVariable): + op = var_base_pointers.get(offset, None) + if op is None: + return MemoryLocationSegment(_offset=None, _size=_size) + else: + segment = MemoryLocationSegment(_offset=None, _size=_size) + return MemoryLocationAbstract(op=op, segment=segment) + elif isinstance(offset, IRAbstractMemLoc): + op = offset + segment = MemoryLocationSegment(_offset=op.offset, _size=_size) + return MemoryLocationAbstract(op=op, segment=segment) + else: # pragma: nocover + raise CompilerPanic(f"invalid offset: {offset} ({type(offset)})") + + @property + def offset(self) -> int | None: # pragma: nocover + raise NotImplementedError + + @property + def size(self) -> int | None: # pragma: nocover + raise NotImplementedError + + @property + def is_offset_fixed(self) -> bool: # pragma: nocover + raise NotImplementedError + + @property + def is_size_fixed(self) -> bool: # pragma: nocover + raise NotImplementedError + + @property + def is_fixed(self) -> bool: # pragma: nocover + raise NotImplementedError + + @property + def is_volatile(self) -> bool: # pragma: nocover + raise 
NotImplementedError
+
+    @staticmethod
+    def may_overlap(loc1: MemoryLocation, loc2: MemoryLocation) -> bool:
+        if loc1.size == 0 or loc2.size == 0:
+            return False
+        if not loc1.is_offset_fixed or not loc2.is_offset_fixed:
+            return True
+        if loc1 is MemoryLocation.UNDEFINED or loc2 is MemoryLocation.UNDEFINED:
+            return True
+        if type(loc1) is not type(loc2):
+            return False
+        if isinstance(loc1, MemoryLocationSegment):
+            assert isinstance(loc2, MemoryLocationSegment)
+            return MemoryLocationSegment.may_overlap_concrete(loc1, loc2)
+        if isinstance(loc1, MemoryLocationAbstract):
+            assert isinstance(loc2, MemoryLocationAbstract)
+            return MemoryLocationAbstract.may_overlap_abstract(loc1, loc2)
+        return False
+
+    def completely_contains(self, other: MemoryLocation) -> bool:  # pragma: nocover
+        raise NotImplementedError
+
+    def create_volatile(self) -> MemoryLocation:  # pragma: nocover
+        raise NotImplementedError
+
+
+@dataclass(frozen=True)
+class MemoryLocationAbstract(MemoryLocation):
+    op: IRAbstractMemLoc
+    segment: MemoryLocationSegment
+
+    @property
+    def offset(self):
+        return self.segment.offset
+
+    @property
+    def size(self):
+        return self.segment.size
+
+    @property
+    def is_offset_fixed(self) -> bool:
+        return True
+
+    @property
+    def is_size_fixed(self) -> bool:
+        return True
+
+    @property
+    def is_fixed(self) -> bool:
+        return True
+
+    @property
+    def is_volatile(self) -> bool:
+        return self.segment.is_volatile
+
+    def create_volatile(self) -> MemoryLocationAbstract:
+        return dc.replace(self, segment=self.segment.create_volatile())
+
+    @staticmethod
+    def may_overlap_abstract(loc1: MemoryLocationAbstract, loc2: MemoryLocationAbstract) -> bool:
+        if loc1.op._id == loc2.op._id:
+            return MemoryLocationSegment.may_overlap_concrete(loc1.segment, loc2.segment)
+        else:
+            return False
+
+    def completely_contains(self, other: MemoryLocation) -> bool:
+        if other == MemoryLocation.UNDEFINED:
+            return False
+        if not isinstance(other, MemoryLocationAbstract):
+            return False
+        if self.size is None:
+            return False
+        if other.size == 0:
+            return True
+        if self.op._id == other.op._id:
+            return self.segment.completely_contains(other.segment)
+        return False
+
+
+@dataclass(frozen=True)
+class MemoryLocationSegment(MemoryLocation):
     """Represents a memory location that can be analyzed for aliasing"""

-    offset: int | None = None
-    size: int | None = None
+    _offset: int | None = None
+    _size: int | None = None
+    _is_volatile: bool = False

     # Locations that should be considered volatile. Example usages of this would
     # be locations that are accessed outside of the current function.
-    is_volatile: bool = False

-    # Initialize after class definition
-    EMPTY: ClassVar[MemoryLocation]
-    UNDEFINED: ClassVar[MemoryLocation]
+    @property
+    def offset(self):
+        return self._offset
+
+    @property
+    def size(self):
+        return self._size

     @property
     def is_offset_fixed(self) -> bool:
@@ -34,29 +177,12 @@ def is_size_fixed(self) -> bool:
     def is_fixed(self) -> bool:
         return self.is_offset_fixed and self.is_size_fixed

-    @classmethod
-    def from_operands(
-        cls, offset: IROperand | int, size: IROperand | int, /, is_volatile: bool = False
-    ) -> MemoryLocation:
-        if isinstance(offset, IRLiteral):
-            _offset = offset.value
-        elif isinstance(offset, IRVariable):
-            _offset = None
-        elif isinstance(offset, int):
-            _offset = offset
-        else:  # pragma: nocover
-            raise CompilerPanic(f"invalid offset: {offset} ({type(offset)})")
-
-        if isinstance(size, IRLiteral):
-            _size = size.value
-        elif isinstance(size, IRVariable):
-            _size = None
-        elif isinstance(size, int):
-            _size = size
-        else:  # pragma: nocover
-            raise CompilerPanic(f"invalid size: {size} ({type(size)})")
+    @property
+    def is_volatile(self) -> bool:
+        return self._is_volatile

-        return cls(_offset, _size, is_volatile)
+    def create_volatile(self) -> MemoryLocationSegment:
+        return dc.replace(self, _is_volatile=True)

     # similar code to memmerging._Interval, but different data structure
     def completely_contains(self, other: MemoryLocation) -> bool:
@@ -72,6 +198,9 @@ def completely_contains(self, other: MemoryLocation) -> bool:
         if not other.is_offset_fixed or not other.is_size_fixed:
             return False

+        if not isinstance(other, MemoryLocationSegment):
+            return False
+
         # Both are known
         assert self.offset is not None and self.size is not None
         assert other.offset is not None and other.size is not None
@@ -81,7 +210,7 @@ def completely_contains(self, other: MemoryLocation) -> bool:
         return start1 <= start2 and end1 >= end2

     @staticmethod
-    def may_overlap(loc1: MemoryLocation, loc2: MemoryLocation) -> bool:
+    def may_overlap_concrete(loc1: MemoryLocationSegment, loc2: MemoryLocationSegment) -> bool:
         """
         Determine if two memory locations may overlap
         """
@@ -124,107 +253,107 @@ def may_overlap(loc1: MemoryLocation, loc2: MemoryLocation) -> bool:
         return True


-MemoryLocation.EMPTY = MemoryLocation(offset=0, size=0)
-MemoryLocation.UNDEFINED = MemoryLocation(offset=None, size=None)
+MemoryLocation.EMPTY = MemoryLocationSegment(_offset=0, _size=0)
+MemoryLocation.UNDEFINED = MemoryLocationSegment(_offset=None, _size=None)


-def get_write_location(inst, addr_space: AddrSpace) -> MemoryLocation:
+def get_write_location(inst, addr_space: AddrSpace, var_base_pointers: dict) -> MemoryLocation:
     """Extract memory location info from an instruction"""
     if addr_space == MEMORY:
-        return _get_memory_write_location(inst)
+        return _get_memory_write_location(inst, var_base_pointers)
     elif addr_space in (STORAGE, TRANSIENT):
-        return _get_storage_write_location(inst, addr_space)
+        return _get_storage_write_location(inst, addr_space, var_base_pointers)
     else:  # pragma: nocover
         raise CompilerPanic(f"Invalid location type: {addr_space}")


-def get_read_location(inst, addr_space: AddrSpace) -> MemoryLocation:
+def get_read_location(inst, addr_space: AddrSpace, var_base_pointers) -> MemoryLocation:
     """Extract memory location info from an instruction"""
     if addr_space == MEMORY:
-        return _get_memory_read_location(inst)
+        return _get_memory_read_location(inst, var_base_pointers)
     elif addr_space in (STORAGE, TRANSIENT):
-        return _get_storage_read_location(inst, addr_space)
+        return _get_storage_read_location(inst, addr_space, var_base_pointers)
     else:  # pragma: nocover
         raise CompilerPanic(f"Invalid location type: {addr_space}")


-def _get_memory_write_location(inst) -> MemoryLocation:
+def _get_memory_write_location(inst, var_base_pointers: dict) -> MemoryLocation:
     opcode = inst.opcode
     if opcode == "mstore":
         dst = inst.operands[1]
-        return MemoryLocation.from_operands(dst, MEMORY.word_scale)
+        return MemoryLocation.from_operands(dst, MEMORY.word_scale, var_base_pointers)
     elif opcode == "mload":
         return MemoryLocation.EMPTY
     elif opcode in ("mcopy", "calldatacopy", "dloadbytes", "codecopy", "returndatacopy"):
         size, _, dst = inst.operands
-        return MemoryLocation.from_operands(dst, size)
+        return MemoryLocation.from_operands(dst, size, var_base_pointers)
     elif opcode == "dload":
-        return MemoryLocation(offset=0, size=32)
+        return MemoryLocationSegment(_offset=0, _size=32)
     elif opcode == "sha3_64":
-        return MemoryLocation(offset=0, size=64)
+        return MemoryLocationSegment(_offset=0, _size=64)
     elif opcode == "invoke":
-        return MemoryLocation(offset=0, size=None)
+        return MemoryLocation.UNDEFINED
     elif opcode == "call":
         size, dst, _, _, _, _, _ = inst.operands
-        return MemoryLocation.from_operands(dst, size)
+        return MemoryLocation.from_operands(dst, size, var_base_pointers)
     elif opcode in ("delegatecall", "staticcall"):
         size, dst, _, _, _, _ = inst.operands
-        return MemoryLocation.from_operands(dst, size)
+        return MemoryLocation.from_operands(dst, size, var_base_pointers)
     elif opcode == "extcodecopy":
         size, _, dst, _ = inst.operands
-        return MemoryLocation.from_operands(dst, size)
+        return MemoryLocation.from_operands(dst, size, var_base_pointers)

-    return MemoryLocation.EMPTY
+    return MemoryLocationSegment.EMPTY


-def _get_memory_read_location(inst) -> MemoryLocation:
+def _get_memory_read_location(inst, var_base_pointers) -> MemoryLocation:
     opcode = inst.opcode
     if opcode == "mstore":
-        return MemoryLocation.EMPTY
+        return MemoryLocationSegment.EMPTY
     elif opcode == "mload":
-        return MemoryLocation.from_operands(inst.operands[0], MEMORY.word_scale)
+        return MemoryLocation.from_operands(inst.operands[0], MEMORY.word_scale, var_base_pointers)
     elif opcode == "mcopy":
         size, src, _ = inst.operands
-        return MemoryLocation.from_operands(src, size)
+        return MemoryLocation.from_operands(src, size, var_base_pointers)
     elif opcode == "dload":
-        return MemoryLocation(offset=0, size=32)
+        return MemoryLocationSegment(_offset=0, _size=32)
     elif opcode == "invoke":
-        return MemoryLocation(offset=0, size=None)
+        return MemoryLocation.UNDEFINED
     elif opcode == "call":
         _, _, size, dst, _, _, _ = inst.operands
-        return MemoryLocation.from_operands(dst, size)
+        return MemoryLocation.from_operands(dst, size, var_base_pointers)
     elif opcode in ("delegatecall", "staticcall"):
         _, _, size, dst, _, _ = inst.operands
-        return MemoryLocation.from_operands(dst, size)
+        return MemoryLocation.from_operands(dst, size, var_base_pointers)
     elif opcode == "return":
         size, src = inst.operands
-        return MemoryLocation.from_operands(src, size)
+        return MemoryLocation.from_operands(src, size, var_base_pointers)
     elif opcode == "create":
         size, src, _value = inst.operands
-        return MemoryLocation.from_operands(src, size)
+        return MemoryLocation.from_operands(src, size, var_base_pointers)
     elif opcode == "create2":
         _salt, size, src, _value = inst.operands
-        return MemoryLocation.from_operands(src, size)
+        return MemoryLocation.from_operands(src, size, var_base_pointers)
     elif opcode == "sha3":
         size, offset = inst.operands
-        return MemoryLocation.from_operands(offset, size)
+        return MemoryLocation.from_operands(offset, size, var_base_pointers)
     elif opcode == "sha3_64":
-        return MemoryLocation(offset=0, size=64)
+        return MemoryLocationSegment(_offset=0, _size=64)
     elif opcode == "log":
         size, src = inst.operands[-2:]
-        return MemoryLocation.from_operands(src, size)
+        return MemoryLocation.from_operands(src, size, var_base_pointers)
     elif opcode == "revert":
         size, src = inst.operands
-        return MemoryLocation.from_operands(src, size)
+        return MemoryLocation.from_operands(src, size, var_base_pointers)

-    return MemoryLocation.EMPTY
+    return MemoryLocationSegment.EMPTY


-def _get_storage_write_location(inst, addr_space: AddrSpace) -> MemoryLocation:
+def _get_storage_write_location(inst, addr_space: AddrSpace, var_base_pointers) -> MemoryLocation:
     opcode = inst.opcode
     if opcode == addr_space.store_op:
         dst = inst.operands[1]
-        return MemoryLocation.from_operands(dst, addr_space.word_scale)
+        return MemoryLocation.from_operands(dst, addr_space.word_scale, var_base_pointers)
     elif opcode == addr_space.load_op:
         return MemoryLocation.EMPTY
     elif opcode in ("call", "delegatecall", "staticcall"):
@@ -237,12 +366,14 @@ def _get_storage_write_location(inst, addr_space: AddrSpace) -> MemoryLocation:
     return MemoryLocation.EMPTY


-def _get_storage_read_location(inst, addr_space: AddrSpace) -> MemoryLocation:
+def _get_storage_read_location(inst, addr_space: AddrSpace, var_base_pointers) -> MemoryLocation:
     opcode = inst.opcode
     if opcode == addr_space.store_op:
         return MemoryLocation.EMPTY
     elif opcode == addr_space.load_op:
-        return MemoryLocation.from_operands(inst.operands[0], addr_space.word_scale)
+        return MemoryLocation.from_operands(
+            inst.operands[0], addr_space.word_scale, var_base_pointers
+        )
     elif opcode in ("call", "delegatecall", "staticcall"):
         return MemoryLocation.UNDEFINED
     elif opcode == "invoke":
@@ -264,3 +395,191 @@ def _get_storage_read_location(inst, addr_space: AddrSpace) -> MemoryLocation:
         return MemoryLocation.UNDEFINED

     return MemoryLocation.EMPTY
+
+
+def in_free_var(var, offset):
+    return var <= offset < var + 32
+
+
+def fix_mem_loc(function: IRFunction):
+    for bb in function.get_basic_blocks():
+        for inst in bb.instructions:
+            if inst.opcode == "codecopyruntime":
+                continue
+            write_op = get_memory_write_op(inst)
+            read_op = get_memory_read_op(inst)
+            if write_op is not None:
+                size = get_write_size(inst)
+                if size is None or not isinstance(write_op.value, int):
+                    continue
+
+                if in_free_var(MemoryPositions.FREE_VAR_SPACE, write_op.value):
+                    offset = write_op.value - MemoryPositions.FREE_VAR_SPACE
+                    _update_write_location(inst, IRAbstractMemLoc.FREE_VAR1.with_offset(offset))
+                elif in_free_var(MemoryPositions.FREE_VAR_SPACE2, write_op.value):
+                    offset = write_op.value - MemoryPositions.FREE_VAR_SPACE2
+                    _update_write_location(inst, IRAbstractMemLoc.FREE_VAR2.with_offset(offset))
+            if read_op is not None:
+                size = _get_read_size(inst)
+                if size is None or not isinstance(read_op.value, int):
+                    continue
+
+                if in_free_var(MemoryPositions.FREE_VAR_SPACE, read_op.value):
+                    offset = read_op.value - MemoryPositions.FREE_VAR_SPACE
+                    _update_read_location(inst, IRAbstractMemLoc.FREE_VAR1.with_offset(offset))
+                elif in_free_var(MemoryPositions.FREE_VAR_SPACE2, read_op.value):
+                    offset = read_op.value - MemoryPositions.FREE_VAR_SPACE2
+                    _update_read_location(inst, IRAbstractMemLoc.FREE_VAR2.with_offset(offset))
+
+
+def get_memory_write_op(inst) -> IROperand | None:
+    opcode = inst.opcode
+    if opcode == "mstore":
+        dst = inst.operands[1]
+        return dst
+    elif opcode in ("mcopy", "calldatacopy", "dloadbytes", "codecopy", "returndatacopy"):
+        _, _, dst = inst.operands
+        return dst
+    elif opcode == "call":
+        _, dst, _, _, _, _, _ = inst.operands
+        return dst
+    elif opcode in ("delegatecall", "staticcall"):
+        _, dst, _, _, _, _ = inst.operands
+        return dst
+    elif opcode == "extcodecopy":
+        _, _, dst, _ = inst.operands
+        return dst
+
+    return None
+
+
+def get_write_size(inst: IRInstruction) -> IROperand | None:
+    opcode = inst.opcode
+    if opcode == "mstore":
+        return IRLiteral(32)
+    elif opcode in ("mcopy", "calldatacopy", "dloadbytes", "codecopy", "returndatacopy"):
+        size, _, _ = inst.operands
+        return size
+    elif opcode == "call":
+        size, _, _, _, _, _, _ = inst.operands
+        return size
+    elif opcode in ("delegatecall", "staticcall"):
+        size, _, _, _, _, _ = inst.operands
+        return size
+    elif opcode == "extcodecopy":
+        size, _, _, _ = inst.operands
+        return size
+
+    return None
+
+
+def get_memory_read_op(inst) -> IROperand | None:
+    opcode = inst.opcode
+    if opcode == "mload":
+        return inst.operands[0]
+    elif opcode == "mcopy":
+        _, src, _ = inst.operands
+        return src
+    elif opcode == "call":
+        _, _, _, dst, _, _, _ = inst.operands
+        return dst
+    elif opcode in ("delegatecall", "staticcall"):
+        _, _, _, dst, _, _ = inst.operands
+        return dst
+    elif opcode == "return":
+        _, src = inst.operands
+        return src
+    elif opcode == "create":
+        _, src, _value = inst.operands
+        return src
+    elif opcode == "create2":
+        _salt, size, src, _value = inst.operands
+        return src
+    elif opcode == "sha3":
+        _, offset = inst.operands
+        return offset
+    elif opcode == "log":
+        _, src = inst.operands[-2:]
+        return src
+    elif opcode == "revert":
+        size, src = inst.operands
+        if size.value == 0:
+            return None
+        return src
+
+    return None
+
+
+def _get_read_size(inst: IRInstruction) -> IROperand | None:
+    opcode = inst.opcode
+    if opcode == "mload":
+        return IRLiteral(32)
+    elif opcode == "mcopy":
+        size, _, _ = inst.operands
+        return size
+    elif opcode == "call":
+        _, _, size, _, _, _, _ = inst.operands
+        return size
+    elif opcode in ("delegatecall", "staticcall"):
+        _, _, size, _, _, _ = inst.operands
+        return size
+    elif opcode == "return":
+        size, _ = inst.operands
+        return size
+    elif opcode == "create":
+        size, _, _ = inst.operands
+        return size
+    elif opcode == "create2":
+        _, size, _, _ = inst.operands
+        return size
+    elif opcode == "sha3":
+        size, _ = inst.operands
+        return size
+    elif opcode == "log":
+        size, _ = inst.operands[-2:]
+        return size
+    elif opcode == "revert":
+        size, _ = inst.operands
+        if size.value == 0:
+            return None
+        return size
+
+    return None
+
+
+def _update_write_location(inst, new_op: IROperand):
+    opcode = inst.opcode
+    if opcode == "mstore":
+        inst.operands[1] = new_op
+    elif opcode in ("mcopy", "calldatacopy", "dloadbytes", "codecopy", "returndatacopy"):
+        inst.operands[2] = new_op
+    elif opcode == "call":
+        inst.operands[1] = new_op
+    elif opcode in ("delegatecall", "staticcall"):
+        inst.operands[1] = new_op
+    elif opcode == "extcodecopy":
+        inst.operands[2] = new_op
+
+
+def _update_read_location(inst, new_op: IROperand):
+    opcode = inst.opcode
+    if opcode == "mload":
+        inst.operands[0] = new_op
+    elif opcode == "mcopy":
+        inst.operands[1] = new_op
+    elif opcode == "call":
+        inst.operands[3] = new_op
+    elif opcode in ("delegatecall", "staticcall"):
+        inst.operands[3] = new_op
+    elif opcode == "return":
+        inst.operands[1] = new_op
+    elif opcode == "create":
+        inst.operands[1] = new_op
+    elif opcode == "create2":
+        inst.operands[2] = new_op
+    elif opcode == "sha3":
+        inst.operands[1] = new_op
+    elif opcode == "log":
+        inst.operands[-1] = new_op
+    elif opcode == "revert":
+        inst.operands[1] = new_op
diff --git a/vyper/venom/parser.py b/vyper/venom/parser.py
index 55d27622c8..61379995ca 100644
--- a/vyper/venom/parser.py
+++ b/vyper/venom/parser.py
@@ -4,6 +4,7 @@
 from lark import Lark, Transformer

 from vyper.venom.basicblock import (
+    IRAbstractMemLoc,
     IRBasicBlock,
     IRInstruction,
     IRLabel,
@@ -43,9 +44,10 @@

     operands_list: operand ("," operand)*

-    operand: VAR_IDENT | CONST | label_ref
+    operand: VAR_IDENT | CONST | MEMLOC | label_ref

     VAR_IDENT: "%" (DIGIT|LETTER|"_"|":")+
+    MEMLOC: "[" (DIGIT)+ "," (DIGIT)+ "]"

     # non-terminal rules for different contexts
     func_name: IDENT | ESCAPED_STRING
@@ -211,7 +213,7 @@ def assignment(self, children) -> IRInstruction:
         if isinstance(value, IRInstruction):
             value.output = to
             return value
-        if isinstance(value, (IRLiteral, IRVariable, IRLabel)):
+        if isinstance(value, (IRLiteral, IRVariable, IRLabel, IRAbstractMemLoc)):
             return IRInstruction("assign", [value], output=to)
         raise TypeError(f"Unexpected value {value} of type {type(value)}")

@@ -269,6 +271,13 @@ def CONST(self, val) -> IRLiteral:
             return IRLiteral(int(val, 16))
         return IRLiteral(int(val))

+    def MEMLOC(self, memloc_ident) -> IRAbstractMemLoc:
+        data: str = memloc_ident[1:-1]
+        _id_str, size_str = data.split(",")
+        _id = int(_id_str)
+        size = int(size_str)
+        return IRAbstractMemLoc(size, force_id=_id)
+
     def IDENT(self, val) -> str:
         return val.value

diff --git a/vyper/venom/passes/__init__.py b/vyper/venom/passes/__init__.py
index 06f8c099eb..54d97b04da 100644
--- a/vyper/venom/passes/__init__.py
+++ b/vyper/venom/passes/__init__.py
@@ -3,8 +3,10 @@
 from .branch_optimization import BranchOptimizationPass
 from .cfg_normalization import CFGNormalization
 from .common_subexpression_elimination import CSE
+from .concretize_mem_loc import ConcretizeMemLocPass
 from .dead_store_elimination import DeadStoreElimination
 from .dft import DFTPass
+from .fixcalloca import FixCalloca
 from .float_allocas import FloatAllocas
 from .function_inliner import FunctionInlinerPass
 from .literals_codesize import ReduceLiteralsCodesize
diff --git a/vyper/venom/passes/algebraic_optimization.py b/vyper/venom/passes/algebraic_optimization.py
index e04dc02fc8..fd0c345576 100644
--- a/vyper/venom/passes/algebraic_optimization.py
+++ b/vyper/venom/passes/algebraic_optimization.py
@@ -179,6 +179,11 @@ def _handle_inst_peephole(self, inst: IRInstruction):
             # no more cases for this instruction
             return

+        if inst.opcode == "gep":
+            if lit_eq(inst.operands[1], 0):
+                self.updater.mk_assign(inst, inst.operands[0])
+            return
+
         if inst.opcode in {"add", "sub", "xor"}:
             # (x - x) == (x ^ x) == 0
             if inst.opcode in ("xor", "sub") and operands[0] == operands[1]:
diff --git a/vyper/venom/passes/concretize_mem_loc.py b/vyper/venom/passes/concretize_mem_loc.py
new file mode 100644
index 0000000000..bbb5d5e0a2
--- /dev/null
+++ b/vyper/venom/passes/concretize_mem_loc.py
@@ -0,0 +1,227 @@
+from collections import defaultdict
+
+from vyper.utils import OrderedSet
+from vyper.venom.analysis import CFGAnalysis, DFGAnalysis
+from vyper.venom.basicblock import (
+    IRAbstractMemLoc,
+    IRBasicBlock,
+    IRInstruction,
+    IRLabel,
+    IRLiteral,
+    IROperand,
+    IRVariable,
+)
+from vyper.venom.function import IRFunction
+from vyper.venom.memory_allocator import MemoryAllocator
+from vyper.venom.memory_location import get_memory_read_op, get_memory_write_op, get_write_size
+from vyper.venom.passes.base_pass import IRPass
+
+
+class ConcretizeMemLocPass(IRPass):
+    allocated_in_bb: dict[IRBasicBlock, int]
+
+    def run_pass(self):
+        self.allocator = self.function.ctx.mem_allocator
+        self.cfg = self.analyses_cache.request_analysis(CFGAnalysis)
+        self.dfg = self.analyses_cache.request_analysis(DFGAnalysis)
+
+        self.allocator.start_fn_allocation()
+
+        orig = self.allocator.curr
+
+        self.mem_liveness = MemLiveness(self.function, self.cfg, self.dfg, self.allocator)
+        self.mem_liveness.analyze()
+
+        livesets = list(self.mem_liveness.livesets.items())
+        already_allocated = [item for item in livesets if item[0]._id in self.allocator.allocated]
+        livesets = [item for item in livesets if item[0]._id not in self.allocator.allocated]
+        livesets.sort(key=lambda x: len(x[1]), reverse=False)
+
+        max_curr = 0
+        for index, (mem, insts) in enumerate(livesets):
+            curr = orig
+            for before_mem, before_insts in already_allocated:
+                if len(OrderedSet.intersection(insts, before_insts)) == 0:
+                    continue
+                place = self.allocator.allocated[before_mem._id]
+                curr = max(place[0] + place[1], curr)
+            for i in range(index):
+                before_mem, before_insts = livesets[i]
+                if len(OrderedSet.intersection(insts, before_insts)) == 0:
+                    continue
+                place = self.allocator.allocated[before_mem._id]
+                curr = max(place[0] + place[1], curr)
+            self.allocator.curr = curr
+            self.allocator.allocate(mem)
+            max_curr = max(self.allocator.curr, max_curr)
+
+        self.allocator.curr = max_curr
+
+        for bb in self.function.get_basic_blocks():
+            self._handle_bb(bb)
+
+        all_allocated = [item[0] for item in already_allocated]
+        all_allocated.extend([item[0] for item in livesets])
+
+        self.allocator.end_fn_allocation(all_allocated, fn=self.function)
+
+        self.analyses_cache.invalidate_analysis(DFGAnalysis)
+
+    def _handle_bb(self, bb: IRBasicBlock):
+        for inst in bb.instructions:
+            if inst.opcode == "codecopyruntime":
+                inst.opcode = "codecopy"
+                continue
+            new_ops = [self._handle_op(op) for op in inst.operands]
+            inst.operands = new_ops
+            if inst.opcode == "gep":
+                inst.opcode = "add"
+            elif inst.opcode == "mem_deploy_start":
+                inst.opcode = "assign"
+
+    def _handle_op(self, op: IROperand) -> IROperand:
+        if isinstance(op, IRAbstractMemLoc) and op._id in self.allocator.allocated:
+            return IRLiteral(self.allocator.allocated[op._id][0] + op.offset)
+        elif isinstance(op, IRAbstractMemLoc):
+            return IRLiteral(self.allocator.allocate(op) + op.offset)
+        else:
+            return op
+
+
+# call-family opcodes: any IRAbstractMemLoc operands they carry stay live
+_CALL_OPCODES = frozenset(["invoke", "staticcall", "call", "delegatecall"])
+
+
+class MemLiveness:
+    function: IRFunction
+    cfg: CFGAnalysis
+    mem_allocator: MemoryAllocator
+
+    liveat: dict[IRInstruction, OrderedSet[IRAbstractMemLoc]]
+    livesets: dict[IRAbstractMemLoc, OrderedSet[IRInstruction]]
+
+    used: dict[IRInstruction, OrderedSet[IRAbstractMemLoc]]
+
+    def __init__(
+        self,
+        function: IRFunction,
+        cfg: CFGAnalysis,
+        dfg: DFGAnalysis,
+        mem_allocator: MemoryAllocator,
+    ):
+        self.function = function
+        self.cfg = cfg
+        self.dfg = dfg
+        self.used = defaultdict(OrderedSet)
+        self.liveat = defaultdict(OrderedSet)
+        self.mem_allocator = mem_allocator
+
+    def analyze(self):
+        found = False
+        upper_bound = len(list(self.function.get_basic_blocks())) ** 2 + 1
+        for _ in range(upper_bound):
+            change = False
+            for bb in self.cfg.dfs_post_walk:
+                change |= self._handle_bb(bb)
+                change |= self._handle_used(bb)
+
+            if not change:
+                found = True
+                break
+
+        assert found, self.function
+
+        self.livesets = defaultdict(OrderedSet)
+        for inst, mems in self.liveat.items():
+            for mem in mems:
+                if mem in self.used[inst]:
+                    self.livesets[mem].add(inst)
+
+    def _handle_bb(self, bb: IRBasicBlock) -> bool:
+        curr: OrderedSet[IRAbstractMemLoc] = OrderedSet()
+        if len(succs := self.cfg.cfg_out(bb)) > 0:
+            for other in (self.liveat[succ.instructions[0]] for succ in succs):
+                curr.update(other)
+
+        before = self.liveat[bb.instructions[0]]
+
+        for inst in reversed(bb.instructions):
+            write_op = get_memory_write_op(inst)
+            write_ops = self._follow_op(write_op)
+            read_op = get_memory_read_op(inst)
+            read_ops = self._follow_op(read_op)
+
+            for read_op in read_ops:
+                assert isinstance(read_op, IRAbstractMemLoc)
+                curr.add(read_op.no_offset())
+
+            if inst.opcode == "invoke":
+                label = inst.operands[0]
+                assert isinstance(label, IRLabel)
+                fn = self.function.ctx.get_function(label)
+                curr.addmany(self.mem_allocator.mems_used[fn])
+
+            if inst.opcode in _CALL_OPCODES:
+                for op in inst.operands:
+                    if not isinstance(op, IRAbstractMemLoc):
+                        continue
+                    curr.add(op.no_offset())
+
+            self.liveat[inst] = curr.copy()
+
+            for write_op in write_ops:
+                assert isinstance(write_op, IRAbstractMemLoc)
+                size = get_write_size(inst)
+                assert size is not None
+                if not isinstance(size, IRLiteral):
+                    continue
+                if write_op in curr and size.value == write_op.size:
+                    curr.remove(write_op.no_offset())
+                    if write_op._id in (op._id for op in read_ops):
+                        curr.add(write_op.no_offset())
+
+        if before != self.liveat[bb.instructions[0]]:
+            return True
+
+        return False
+
+    def _handle_used(self, bb: IRBasicBlock) -> bool:
+        curr: OrderedSet[IRAbstractMemLoc] = OrderedSet(self.function.allocated_args.values())
+        if len(preds := self.cfg.cfg_in(bb)) > 0:
+            for other in (self.used[pred.instructions[-1]] for pred in preds):
+                curr.update(other)
+
+        before = self.used[bb.instructions[-1]]
+        for inst in bb.instructions:
+            for op in inst.operands:
+                if not isinstance(op, IRAbstractMemLoc):
+                    continue
+                curr.add(op.no_offset())
+            if inst.opcode == "invoke":
+                label = inst.operands[0]
+                assert isinstance(label, IRLabel)
+                fn = self.function.ctx.get_function(label)
+                curr.addmany(self.mem_allocator.mems_used[fn])
+            self.used[inst] = curr.copy()
+        return before != curr
+
+    def _follow_op(self, op: IROperand | None) -> set[IRAbstractMemLoc]:
+        if op is None:
+            return set()
+        if isinstance(op, IRAbstractMemLoc):
+            return {op}
+        if not isinstance(op, IRVariable):
+            return set()
+
+        inst = self.dfg.get_producing_instruction(op)
+        assert inst is not None
+        if inst.opcode == "gep":
+            mem = inst.operands[0]
+            return self._follow_op(mem)
+        elif inst.opcode == "phi":
+            res = set()
+            for _, var in inst.phi_operands:
+                src = self._follow_op(var)
+                res.update(src)
+            return res
+        return set()
diff --git a/vyper/venom/passes/fixcalloca.py b/vyper/venom/passes/fixcalloca.py
new file mode 100644
index 0000000000..12a5dd5994
--- /dev/null
+++ b/vyper/venom/passes/fixcalloca.py
@@ -0,0 +1,55 @@
+from collections import deque
+
+from vyper.venom.analysis import DFGAnalysis, FCGAnalysis
+from vyper.venom.basicblock import IRInstruction, IRLabel, IRLiteral
+from vyper.venom.function import IRFunction
+from vyper.venom.passes.base_pass import IRGlobalPass
+from vyper.venom.passes.machinery.inst_updater import InstUpdater
+
+
+class FixCalloca(IRGlobalPass):
+    def run_pass(self):
+        for fn in self.ctx.get_functions():
+            self.fcg = self.analyses_caches[fn].request_analysis(FCGAnalysis)
+            self.dfg = self.analyses_caches[fn].request_analysis(DFGAnalysis)
+            self.updater = InstUpdater(self.dfg)
+            self._handle_fn(fn)
+
+    def _handle_fn(self, fn: IRFunction):
+        for bb in fn.get_basic_blocks():
+            for inst in bb.instructions:
+                if inst.opcode != "calloca":
+                    continue
+
+                assert inst.output is not None
+                assert len(inst.operands) == 3
+                size, _id, callsite = inst.operands
+                assert isinstance(callsite, IRLabel)
+                assert isinstance(_id, IRLiteral)
+
+                called_name = callsite.value.rsplit("_call", maxsplit=1)[0]
+
+                called = self.ctx.get_function(IRLabel(called_name))
+                if _id.value not in called.allocated_args:
+                    self._remove_unused_calloca(inst)
+                    continue
+                memloc = called.allocated_args[_id.value]
+
+                inst.operands = [memloc, _id]
+
+    def _remove_unused_calloca(self, inst: IRInstruction):
+        assert inst.output is not None
+        to_remove = set()
+        worklist: deque = deque()
+        worklist.append(inst)
+        while len(worklist) > 0:
+            curr = worklist.popleft()
+            if curr in to_remove:
+                continue
+            to_remove.add(curr)
+
+            if curr.output is not None:
+                uses = self.dfg.get_uses(curr.output)
+                worklist.extend(uses)
+
+        self.updater.nop_multi(to_remove)
diff --git a/vyper/venom/passes/function_inliner.py b/vyper/venom/passes/function_inliner.py
index 6d6f355d2b..e0eea0e218 100644
--- a/vyper/venom/passes/function_inliner.py
+++ b/vyper/venom/passes/function_inliner.py
@@ -108,7 +108,8 @@ def _inline_function(self, func: IRFunction, call_sites: List[IRInstruction]) ->
             # inlined any callsites (see demotion of calloca
             # to alloca below). this handles both cases.
             if inst.opcode in ("alloca", "calloca"):
-                _, _, alloca_id_op = inst.operands
+                assert len(inst.operands) >= 2, inst
+                alloca_id_op = inst.operands[1]
                 alloca_id = alloca_id_op.value
                 assert isinstance(alloca_id, int)  # help mypy
                 if alloca_id in callocas:
@@ -122,7 +123,7 @@ def _inline_function(self, func: IRFunction, call_sites: List[IRInstruction]) ->
                 callocas[alloca_id] = inst

             if inst.opcode == "palloca":
-                _, _, alloca_id_op = inst.operands
+                _, alloca_id_op = inst.operands
                 alloca_id = alloca_id_op.value
                 assert isinstance(alloca_id, int)
                 if alloca_id not in callocas:
@@ -139,7 +140,7 @@ def _inline_function(self, func: IRFunction, call_sites: List[IRInstruction]) ->
             for inst in bb.instructions:
                 if inst.opcode != "calloca":
                     continue
-                _, _, alloca_id = inst.operands
+                _, alloca_id = inst.operands
                 if alloca_id in found:
                     # demote to alloca so that mem2var will work
                     inst.opcode = "alloca"
diff --git a/vyper/venom/passes/load_elimination.py b/vyper/venom/passes/load_elimination.py
index 1b051e8e03..cde0efe39b 100644
--- a/vyper/venom/passes/load_elimination.py
+++ b/vyper/venom/passes/load_elimination.py
@@ -3,23 +3,49 @@
 from vyper.utils import OrderedSet
 from vyper.venom.analysis import CFGAnalysis, DFGAnalysis, LivenessAnalysis
 from vyper.venom.analysis.analysis import IRAnalysis
-from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRLiteral, IROperand, IRVariable
+from vyper.venom.basicblock import (
+    IRAbstractMemLoc,
+    IRBasicBlock,
+    IRInstruction,
+    IRLiteral,
+    IROperand,
+    IRVariable,
+)
 from vyper.venom.effects import Effects
 from vyper.venom.passes.base_pass import InstUpdater, IRPass

 Lattice = dict[IROperand, OrderedSet[IROperand]]


-def _conflict(store_opcode: str, k1: IRLiteral, k2: IRLiteral):
-    ptr1, ptr2 = k1.value, k2.value
-    # hardcode the size of store opcodes for now. maybe refactor to use
-    # vyper.evm.address_space
+def _conflict_lit(store_opcode: str, ptr1: int, ptr2: int):
     if store_opcode == "mstore":
         return abs(ptr1 - ptr2) < 32
     assert store_opcode in ("sstore", "tstore"), "unhandled store opcode"
     return abs(ptr1 - ptr2) < 1


+def _conflict(
+    store_opcode: str, k1: IRLiteral | IRAbstractMemLoc, k2: IRLiteral | IRAbstractMemLoc, tmp=None
+):
+    # hardcode the size of store opcodes for now. maybe refactor to use
+    # vyper.evm.address_space
+    if store_opcode == "mstore":
+        if isinstance(k1, IRLiteral) and isinstance(k2, IRLiteral):
+            return _conflict_lit(store_opcode, k1.value, k2.value)
+        if not isinstance(k1, IRAbstractMemLoc) or not isinstance(k2, IRAbstractMemLoc):
+            # this used to be an assert, but it fired under
+            # --enable-compiler-debug-mode (unclear why); be conservative
+            return True
+        if k1._id == k2._id:
+            return _conflict_lit(store_opcode, k1.offset, k2.offset)
+        else:
+            return False
+
+    assert isinstance(k1, IRLiteral) and isinstance(k2, IRLiteral)
+    ptr1, ptr2 = k1.value, k2.value
+    return _conflict_lit(store_opcode, ptr1, ptr2)
+
+
 class LoadAnalysis(IRAnalysis):
     InstToLattice = dict[IRInstruction, Lattice]
     lattice: dict[Effects | str, InstToLattice]
@@ -71,8 +97,10 @@ def _merge(self, bb: IRBasicBlock) -> Lattice:

         return res

-    def get_literal(self, op):
+    def get_memloc(self, op):
         op = self.dfg._traverse_assign_chain(op)
+        if isinstance(op, IRAbstractMemLoc):
+            return op
         if isinstance(op, IRLiteral):
             return op
         return None
@@ -92,7 +120,7 @@ def _handle_bb(
                 self.inst_to_lattice[inst] = lattice.copy()
                 # mstore [val, ptr]
                 val, ptr = inst.operands
-                lit = self.get_literal(ptr)
+                lit = self.get_memloc(ptr)
                 if lit is None:
                     lattice.clear()
                     lattice[ptr] = OrderedSet([val])
@@ -102,7 +130,7 @@ def _handle_bb(

                 # kick out any conflicts
                 for existing_key in lattice.copy().keys():
-                    existing_lit = self.get_literal(existing_key)
+                    existing_lit = self.get_memloc(existing_key)
                     if existing_lit is None:
                         # a variable in the lattice. assign this ptr in the lattice
                         # and flush everything else.
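
For reference, a minimal sketch of the _conflict semantics introduced above. This is illustrative only: _conflict is a private helper in vyper/venom/passes/load_elimination, and the IRAbstractMemLoc constructor arguments mirror the ones used elsewhere in this diff.

    from vyper.venom.basicblock import IRAbstractMemLoc
    from vyper.venom.passes.load_elimination import _conflict

    # two views into the same abstract allocation (_id 0), 16 bytes apart
    a = IRAbstractMemLoc(64, offset=0, force_id=0)
    b = IRAbstractMemLoc(64, offset=16, force_id=0)
    # a different allocation entirely
    c = IRAbstractMemLoc(64, offset=0, force_id=1)

    assert _conflict("mstore", a, b)      # same _id and |0 - 16| < 32: words may overlap
    assert not _conflict("mstore", a, c)  # distinct _ids never alias before allocation

That is, abstract locations only conflict within the same allocation; distinct _ids are disjoint by construction until ConcretizeMemLocPass assigns them concrete offsets.
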
@@ -137,7 +165,7 @@ def run_pass(self):
         self.cfg = self.analyses_cache.request_analysis(CFGAnalysis)
         self.dfg = self.analyses_cache.request_analysis(DFGAnalysis)
         self.updater = InstUpdater(self.dfg)
-        self.load_analysis = self.analyses_cache.request_analysis(LoadAnalysis)
+        self.load_analysis = self.analyses_cache.force_analysis(LoadAnalysis)

         self._run(Effects.MEMORY, "mload", "mstore")
         self._run(Effects.TRANSIENT, "tload", "tstore")
diff --git a/vyper/venom/passes/mem2var.py b/vyper/venom/passes/mem2var.py
index bec17d0393..a5d47da52c 100644
--- a/vyper/venom/passes/mem2var.py
+++ b/vyper/venom/passes/mem2var.py
@@ -1,6 +1,6 @@
 from vyper.utils import all2
 from vyper.venom.analysis import CFGAnalysis, DFGAnalysis, LivenessAnalysis
-from vyper.venom.basicblock import IRInstruction, IRVariable
+from vyper.venom.basicblock import IRAbstractMemLoc, IRInstruction, IROperand, IRVariable
 from vyper.venom.function import IRFunction
 from vyper.venom.ir_node_to_venom import ENABLE_NEW_CALL_CONV
 from vyper.venom.passes.base_pass import InstUpdater, IRPass
@@ -15,9 +15,11 @@ class Mem2Var(IRPass):
     function: IRFunction

     def run_pass(self):
+        self.mem_alloc = self.function.ctx.mem_allocator
         self.analyses_cache.request_analysis(CFGAnalysis)
         dfg = self.analyses_cache.request_analysis(DFGAnalysis)
         self.updater = InstUpdater(dfg)
+        self.dfg = dfg

         self.var_name_count = 0
         for var, inst in dfg.outputs.copy().items():
@@ -25,6 +27,8 @@ def run_pass(self):
                 self._process_alloca_var(dfg, inst, var)
             elif inst.opcode == "palloca":
                 self._process_palloca_var(dfg, inst, var)
+            elif inst.opcode == "calloca":
+                self._process_calloca(inst)

         self.analyses_cache.invalidate_analysis(LivenessAnalysis)

@@ -34,39 +38,65 @@ def _mk_varname(self, varname: str, alloca_id: int):
         self.var_name_count += 1
         return varname

-    def _process_alloca_var(self, dfg: DFGAnalysis, alloca_inst, var: IRVariable):
+    def _process_alloca_var(self, dfg: DFGAnalysis, alloca_inst: IRInstruction, var: IRVariable):
         """
         Process alloca allocated variable. If it is only used by
         mstore/mload/return instructions, it is promoted to a stack variable.
         Otherwise, it is left as is.
         """
-        uses = dfg.get_uses(var)
-        if not all2(inst.opcode in ["mstore", "mload", "return"] for inst in uses):
-            return
-        alloca_id = alloca_inst.operands[2]
+        assert len(alloca_inst.operands) == 2, (alloca_inst, alloca_inst.parent)
+
+        mem_loc, alloca_id = alloca_inst.operands
         var_name = self._mk_varname(var.value, alloca_id.value)
         var = IRVariable(var_name)
+        assert alloca_inst.output is not None
+        uses = dfg.get_uses(alloca_inst.output)
+
+        self.updater.mk_assign(alloca_inst, mem_loc)
+
+        if any(inst.opcode == "add" for inst in uses):
+            self._fix_adds(alloca_inst, mem_loc)
+            return
+
+        if not all2(inst.opcode in ["mstore", "mload", "return"] for inst in uses):
+            return
+
+        assert isinstance(mem_loc, IRAbstractMemLoc)
+        size = mem_loc.size
+
         for inst in uses.copy():
             if inst.opcode == "mstore":
-                self.updater.mk_assign(inst, inst.operands[0], new_output=var)
+                if size <= 32:
+                    self.updater.mk_assign(inst, inst.operands[0], new_output=var)
+                else:
+                    self.updater.update_operands(inst, {alloca_inst.output: mem_loc})
             elif inst.opcode == "mload":
-                self.updater.mk_assign(inst, var)
+                if size <= 32:
+                    self.updater.mk_assign(inst, var)
+                else:
+                    self.updater.update_operands(inst, {alloca_inst.output: mem_loc})
             elif inst.opcode == "return":
-                self.updater.add_before(inst, "mstore", [var, inst.operands[1]])
+                if size <= 32:
+                    self.updater.add_before(inst, "mstore", [var, mem_loc])
+                inst.operands[1] = mem_loc

     def _process_palloca_var(self, dfg: DFGAnalysis, palloca_inst: IRInstruction, var: IRVariable):
         """
         Process alloca allocated variable. If it is only used by mstore/mload
         instructions, it is promoted to a stack variable. Otherwise, it is left
         as is.
         """
-        uses = dfg.get_uses(var)
-        if not all2(inst.opcode in ["mstore", "mload"] for inst in uses):
+        mem_loc, alloca_id = palloca_inst.operands
+        assert palloca_inst.output is not None
+        uses = dfg.get_uses(palloca_inst.output)
+
+        self.updater.mk_assign(palloca_inst, mem_loc)
+        if any(inst.opcode == "add" for inst in uses):
+            self._fix_adds(palloca_inst, mem_loc)
             return
-        ofst, _size, alloca_id = palloca_inst.operands
-        var_name = self._mk_varname(var.value, alloca_id.value)
-        var = IRVariable(var_name)
+
+        if not all2(inst.opcode in ["mstore", "mload"] for inst in uses):
+            return

         # some value given to us by the calling convention
         fn = self.function
@@ -75,15 +105,46 @@ def _process_palloca_var(self, dfg: DFGAnalysis, palloca_inst: IRInstruction, va
             # on alloca_id) is a bit kludgey, but we will live.
             param = fn.get_param_by_id(alloca_id.value)
             if param is None:
-                self.updater.update(palloca_inst, "mload", [ofst], new_output=var)
+                self.updater.update(palloca_inst, "mload", [mem_loc], new_output=var)
             else:
                 self.updater.update(palloca_inst, "assign", [param.func_var], new_output=var)
         else:
             # otherwise, it comes from memory, convert to an mload.
-            self.updater.update(palloca_inst, "mload", [ofst], new_output=var)
+            self.updater.update(palloca_inst, "mload", [mem_loc], new_output=var)
+
+        assert isinstance(mem_loc, IRAbstractMemLoc)
+        size = mem_loc.size

         for inst in uses.copy():
             if inst.opcode == "mstore":
-                self.updater.mk_assign(inst, inst.operands[0], new_output=var)
+                if size <= 32:
+                    self.updater.mk_assign(inst, inst.operands[0], new_output=var)
+                else:
+                    self.updater.update_operands(inst, {palloca_inst.output: mem_loc})
             elif inst.opcode == "mload":
-                self.updater.mk_assign(inst, var)
+                if size <= 32:
+                    self.updater.mk_assign(inst, var)
+                else:
+                    self.updater.update_operands(inst, {palloca_inst.output: mem_loc})
+
+    def _process_calloca(self, inst: IRInstruction):
+        assert inst.opcode == "calloca"
+        assert inst.output is not None
+        assert len(inst.operands) == 2
+        memloc = inst.operands[0]
+
+        assert isinstance(memloc, IRAbstractMemLoc)
+
+        self.updater.mk_assign(inst, memloc)
+        self._fix_adds(inst, memloc)
+
+    def _fix_adds(self, mem_src: IRInstruction, mem_op: IROperand):
+        assert mem_src.output is not None
+        uses = self.dfg.get_uses(mem_src.output)
+        for inst in uses.copy():
+            if inst.opcode != "add":
+                continue
+            other = [op for op in inst.operands if op != mem_src.output]
+            assert len(other) == 1
+            self.updater.update(inst, "gep", [mem_op, other[0]])
+            self._fix_adds(inst, inst.output)
diff --git a/vyper/venom/passes/sccp/sccp.py b/vyper/venom/passes/sccp/sccp.py
index ee9580c450..eeafc6faac 100644
--- a/vyper/venom/passes/sccp/sccp.py
+++ b/vyper/venom/passes/sccp/sccp.py
@@ -7,6 +7,7 @@
 from vyper.utils import OrderedSet
 from vyper.venom.analysis import CFGAnalysis, DFGAnalysis, IRAnalysesCache, LivenessAnalysis
 from vyper.venom.basicblock import (
+    IRAbstractMemLoc,
     IRBasicBlock,
     IRInstruction,
     IRLabel,
@@ -36,7 +37,7 @@ class FlowWorkItem:

 WorkListItem = Union[FlowWorkItem, SSAWorkListItem]

-LatticeItem = Union[LatticeEnum, IRLiteral, IRLabel]
+LatticeItem = Union[LatticeEnum, IRLiteral, IRLabel, IRAbstractMemLoc]
 Lattice = dict[IRVariable, LatticeItem]

@@ -155,7 +156,7 @@ def _set_lattice(self, op: IROperand, value: LatticeItem):
         self.lattice[op] = value

     def _eval_from_lattice(self, op: IROperand) -> LatticeItem:
-        if isinstance(op, (IRLiteral, IRLabel)):
+        if isinstance(op, (IRLiteral, IRLabel, IRAbstractMemLoc)):
             return op

         assert isinstance(op, IRVariable), f"Not a variable: {op}"
@@ -189,6 +190,16 @@ def _visit_expr(self, inst: IRInstruction):
             out = self._eval_from_lattice(inst.operands[0])
             self._set_lattice(inst.output, out)
             self._add_ssa_work_items(inst)
+        elif opcode == "gep":
+            assert inst.output is not None, inst
+            mem = self._eval_from_lattice(inst.operands[0])
+            offset = self._eval_from_lattice(inst.operands[1])
+            if not isinstance(mem, IRAbstractMemLoc) or not isinstance(offset, IRLiteral):
+                out = LatticeEnum.BOTTOM
+            else:
+                out = IRAbstractMemLoc(mem.size, offset=mem.offset + offset.value, force_id=mem._id)
+            self._set_lattice(inst.output, out)
+            self._add_ssa_work_items(inst)
         elif opcode == "jmp":
             target = self.fn.get_basic_block(inst.operands[0].value)
             self.work_list.append(FlowWorkItem(inst.parent, target))
@@ -242,7 +253,7 @@ def finalize(ret):
         ops: list[IRLiteral] = []
         for op in inst.operands:
             # Evaluate the operand according to the lattice
-            if isinstance(op, IRLabel):
+            if isinstance(op, (IRLabel, IRAbstractMemLoc)):
                 return finalize(LatticeEnum.BOTTOM)
             elif isinstance(op, IRVariable):
                 eval_result = self.lattice[op]
@@ -260,7 +271,10 @@ def finalize(ret):
             if eval_result is LatticeEnum.TOP:
                 return finalize(LatticeEnum.TOP)

-            assert isinstance(eval_result, IRLiteral), (op, eval_result, inst.parent.label, inst)
+            if isinstance(eval_result, IRAbstractMemLoc):
+                return finalize(LatticeEnum.BOTTOM)
+
+            assert isinstance(eval_result, IRLiteral), (inst.parent.label, op, inst, eval_result)
             ops.append(eval_result)

         # If we haven't found BOTTOM yet, evaluate the operation
@@ -319,7 +333,7 @@ def _replace_constants(self, inst: IRInstruction):
         for i, op in enumerate(inst.operands):
             if isinstance(op, IRVariable):
                 lat = self.lattice[op]
-                if isinstance(lat, IRLiteral):
+                if isinstance(lat, (IRLiteral, IRAbstractMemLoc)):
                     inst.operands[i] = lat
diff --git a/vyper/venom/passes/simplify_cfg.py b/vyper/venom/passes/simplify_cfg.py
index bf6306a91a..c3a588c0fe 100644
--- a/vyper/venom/passes/simplify_cfg.py
+++ b/vyper/venom/passes/simplify_cfg.py
@@ -45,6 +45,13 @@ def _merge_jump(self, a: IRBasicBlock, b: IRBasicBlock):
             self.cfg.remove_cfg_in(next_bb, b)
             self.cfg.add_cfg_in(next_bb, a)

+        for next_bb in self.cfg.cfg_out(a):
+            for inst in next_bb.instructions:
+                # assume phi instructions are at beginning of bb
+                if inst.opcode != "phi":
+                    break
+                inst.operands[inst.operands.index(b.label)] = a.label
+
         self.function.remove_basic_block(b)

     def _collapse_chained_blocks_r(self, bb: IRBasicBlock):
diff --git a/vyper/venom/venom_to_assembly.py b/vyper/venom/venom_to_assembly.py
index a4a2de0666..df3084e758 100644
--- a/vyper/venom/venom_to_assembly.py
+++ b/vyper/venom/venom_to_assembly.py
@@ -141,7 +141,7 @@ def _ofst(label: Label, value: int) -> list[Any]:
 # with the assembler. My suggestion is to let this be for now, and we can
 # refactor it later when we are finished phasing out the old IR.
 class VenomCompiler:
-    ctxs: list[IRContext]
+    ctx: IRContext
     label_counter = 0
     visited_basicblocks: OrderedSet  # {IRBasicBlock}
     liveness: LivenessAnalysis
@@ -396,11 +396,6 @@ def _generate_evm_for_instruction(
         if opcode in ["jmp", "djmp", "jnz", "invoke"]:
             operands = list(inst.get_non_label_operands())

-        elif opcode in ("alloca", "palloca", "calloca"):
-            assert len(inst.operands) == 3, inst
-            offset, _size, _id = inst.operands
-            operands = [offset]
-
         # iload and istore are special cases because they can take a literal
         # that is handled specialy with the _OFST macro. Look below, after the
         # stack reordering.
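
A note on the new MEMLOC token in vyper/venom/parser.py: "[id,size]" is the textual form of an abstract memory location. A hand-decoded sketch of what the MEMLOC transformer produces (the decode is inlined here for illustration; attribute names follow their uses throughout this diff):

    from vyper.venom.basicblock import IRAbstractMemLoc

    token = "[7,96]"  # MEMLOC: "[" (DIGIT)+ "," (DIGIT)+ "]"
    _id_str, size_str = token[1:-1].split(",")
    loc = IRAbstractMemLoc(int(size_str), force_id=int(_id_str))
    assert loc._id == 7 and loc.size == 96

So `%p = [7,96]` parses to an assign of a 96-byte abstract location with id 7; ConcretizeMemLocPass later rewrites any such operand into a concrete IRLiteral offset.
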
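The two gep rules added in this diff compose with that. A sketch of the constant folding SCCP's _visit_expr performs for `%q = gep %p, 32` when %p is a known abstract location (same constructor assumptions as the previous sketch):

    from vyper.venom.basicblock import IRAbstractMemLoc

    base = IRAbstractMemLoc(64, offset=0, force_id=3)
    # SCCP folds a gep by a literal into a re-offset location with the same _id:
    folded = IRAbstractMemLoc(base.size, offset=base.offset + 32, force_id=base._id)
    assert folded._id == base._id and folded.offset == 32

A gep by zero needs no folding at all: the algebraic_optimization.py peephole rewrites `gep %p, 0` to `assign %p`, and ConcretizeMemLocPass lowers any gep that survives to a plain add.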