diff --git a/tests/functional/venom/parser/test_multi_output_invoke.py b/tests/functional/venom/parser/test_multi_output_invoke.py new file mode 100644 index 0000000000..d8ff51fd27 --- /dev/null +++ b/tests/functional/venom/parser/test_multi_output_invoke.py @@ -0,0 +1,31 @@ +from tests.venom_utils import parse_venom +from vyper.venom.basicblock import IRInstruction + + +def _make_src(a: int, b: int) -> str: + return f""" +function main {{ +main: + %x, %y = invoke @f + sink %x, %y +}} + +function f {{ +f: + %retpc = param + %v0 = assign {a} + %v1 = assign {b} + ret %v0, %v1, %retpc +}} +""" + + +def test_parse_multi_output_invoke_builds_two_outputs(): + src = _make_src(7, 9) + ctx = parse_venom(src) + fn = ctx.get_function(next(iter(ctx.functions.keys()))) + main_bb = fn.get_basic_block("main") + inst = next(inst for inst in main_bb.instructions if inst.opcode == "invoke") + assert isinstance(inst, IRInstruction) + outs = inst.get_outputs() + assert len(outs) == 2 diff --git a/tests/functional/venom/parser/test_parsing.py b/tests/functional/venom/parser/test_parsing.py index a45fb4a658..b168b8d93d 100644 --- a/tests/functional/venom/parser/test_parsing.py +++ b/tests/functional/venom/parser/test_parsing.py @@ -151,7 +151,7 @@ def test_multi_function(): expected_ctx.add_function(entry_fn := IRFunction(IRLabel("entry"))) entry_bb = entry_fn.get_basic_block("entry") - entry_bb.append_invoke_instruction([IRLabel("check_cv")], returns=False) + entry_bb.append_invoke_instruction([IRLabel("check_cv")], returns=0) entry_bb.append_instruction("jmp", IRLabel("wow")) entry_fn.append_basic_block(wow_bb := IRBasicBlock(IRLabel("wow"), entry_fn)) @@ -213,7 +213,7 @@ def test_multi_function_and_data(): expected_ctx.add_function(entry_fn := IRFunction(IRLabel("entry"))) entry_bb = entry_fn.get_basic_block("entry") - entry_bb.append_invoke_instruction([IRLabel("check_cv")], returns=False) + entry_bb.append_invoke_instruction([IRLabel("check_cv")], returns=0) entry_bb.append_instruction("jmp", IRLabel("wow")) entry_fn.append_basic_block(wow_bb := IRBasicBlock(IRLabel("wow"), entry_fn)) @@ -366,3 +366,37 @@ def test_phis(): parsed_fn = next(iter(ctx.functions.values())) assert_bb_eq(parsed_fn.get_basic_block(expect_bb.label.name), expect_bb) + + +def test_multi_output_last_var(): + source = """ + function main { + main: + %1, %2 = invoke @f + %3, %4, %5 = invoke @g + sink %1, %2, %3, %4, %5 + } + + function f { + f: + %retpc = param + ret 10, 20, %retpc + } + + function g { + g: + %retpc = param + ret 30, 40, 50, %retpc + } + """ + + parsed_ctx = parse_venom(source) + + main_fn = parsed_ctx.get_function(IRLabel("main")) + assert main_fn.last_variable == 5 + + f_fn = parsed_ctx.get_function(IRLabel("f")) + assert f_fn.last_variable == 0 + + g_fn = parsed_ctx.get_function(IRLabel("g")) + assert g_fn.last_variable == 0 diff --git a/tests/unit/compiler/venom/test_calling_convention.py b/tests/unit/compiler/venom/test_calling_convention.py new file mode 100644 index 0000000000..199ffd945f --- /dev/null +++ b/tests/unit/compiler/venom/test_calling_convention.py @@ -0,0 +1,163 @@ +import pytest + +from tests.venom_utils import parse_venom +from vyper.venom.check_venom import ( + InconsistentReturnArity, + InvokeArityMismatch, + MultiOutputNonInvoke, + check_calling_convention, +) + + +def _assert_raises(exc_group, exc_type): + assert any(isinstance(err, exc_type) for err in exc_group.exceptions) + + +def test_invoke_arity_match_zero(): + src = """ + function main { + main: + %p = source + invoke @f, %p + } + + 
function f { + main: + %p = param + ret @retpc + } + """ + ctx = parse_venom(src) + # Should not raise: callee returns 0, call site binds 0 + check_calling_convention(ctx) + + +def test_invoke_arity_match_one(): + src = """ + function main { + main: + %p = source + %ret = invoke @f, %p + sink %ret + } + + function f { + main: + %p = param + %one = add %p, 1 + ret %one, @retpc + } + """ + ctx = parse_venom(src) + # Should not raise: callee returns 1, call site binds 1 + check_calling_convention(ctx) + + +def test_invoke_arity_mismatch_too_few_outputs(): + src = """ + function main { + main: + %p = source + invoke @f, %p + } + + function f { + main: + %p = param + %one = add %p, 1 + ret %one, @retpc + } + """ + ctx = parse_venom(src) + with pytest.raises(ExceptionGroup) as excinfo: + check_calling_convention(ctx) + _assert_raises(excinfo.value, InvokeArityMismatch) + + +def test_invoke_arity_mismatch_too_many_outputs(): + src = """ + function main { + main: + %p = source + %ret = invoke @f, %p + sink %ret + } + + function f { + main: + %p = param + ret @retpc + } + """ + ctx = parse_venom(src) + with pytest.raises(ExceptionGroup) as excinfo: + check_calling_convention(ctx) + _assert_raises(excinfo.value, InvokeArityMismatch) + + +def test_inconsistent_callee_return_arity(): + src = """ + function main { + main: + %p = source + invoke @f, %p + } + + function f { + entry: + %p = param + jnz %p, @then, @else + then: + %one = add %p, 1 + ret %one, @retpc + else: + ret @retpc + } + """ + ctx = parse_venom(src) + with pytest.raises(ExceptionGroup) as excinfo: + check_calling_convention(ctx) + _assert_raises(excinfo.value, InconsistentReturnArity) + + +def test_inconsistent_callee_return_arity_no_spurious_mismatch(): + # When callee has inconsistent return arity, we should only report + # InconsistentReturnArity, not InvokeArityMismatch for the call site. 
+ src = """ + function main { + main: + %p = source + %ret = invoke @f, %p + sink %ret + } + + function f { + entry: + %p = param + jnz %p, @then, @else + then: + %one = add %p, 1 + ret %one, @retpc + else: + ret @retpc + } + """ + ctx = parse_venom(src) + with pytest.raises(ExceptionGroup) as excinfo: + check_calling_convention(ctx) + _assert_raises(excinfo.value, InconsistentReturnArity) + assert not any(isinstance(err, InvokeArityMismatch) for err in excinfo.value.exceptions) + + +def test_multi_lhs_non_invoke_rejected(): + src = """ + function main { + main: + %x, %y = add 1, 2 + sink %x, %y + } + """ + ctx = parse_venom(src) + with pytest.raises(ExceptionGroup) as excinfo: + check_calling_convention(ctx) + _assert_raises(excinfo.value, MultiOutputNonInvoke) diff --git a/tests/unit/compiler/venom/test_dominator_tree.py b/tests/unit/compiler/venom/test_dominator_tree.py index 30a2e4564e..dc2f8f26f2 100644 --- a/tests/unit/compiler/venom/test_dominator_tree.py +++ b/tests/unit/compiler/venom/test_dominator_tree.py @@ -66,8 +66,8 @@ def test_phi_placement(): bb1, bb2, bb3, bb4, bb5, bb6, bb7 = [fn.get_basic_block(str(i)) for i in range(1, 8)] x = IRVariable("%x") - bb1.insert_instruction(IRInstruction("mload", [IRLiteral(0)], x), 0) - bb2.insert_instruction(IRInstruction("add", [x, IRLiteral(1)], x), 0) + bb1.insert_instruction(IRInstruction("mload", [IRLiteral(0)], [x]), 0) + bb2.insert_instruction(IRInstruction("add", [x, IRLiteral(1)], [x]), 0) bb7.insert_instruction(IRInstruction("mstore", [x, IRLiteral(0)]), 0) ac = IRAnalysesCache(fn) diff --git a/tests/unit/compiler/venom/test_invoke_multi_return.py b/tests/unit/compiler/venom/test_invoke_multi_return.py new file mode 100644 index 0000000000..8a642a9557 --- /dev/null +++ b/tests/unit/compiler/venom/test_invoke_multi_return.py @@ -0,0 +1,38 @@ +import pytest + +from tests.hevm import hevm_check_venom_ctx +from vyper.venom.parser import parse_venom + + +@pytest.mark.hevm +def test_invoke_two_returns_executes_correctly(): + a, b = 7, 9 + + pre = parse_venom( + f""" + function main {{ + main: + %a, %b = invoke @f + sink %a, %b + }} + + function f {{ + f: + %retpc = param + %v0 = {a} + %v1 = {b} + ret %v0, %v1, %retpc + }} + """ + ) + + post = parse_venom( + f""" + function main {{ + main: + sink {a}, {b} + }} + """ + ) + + hevm_check_venom_ctx(pre, post) diff --git a/tests/unit/compiler/venom/test_venom_to_assembly.py b/tests/unit/compiler/venom/test_venom_to_assembly.py index 16a130153a..51bda277de 100644 --- a/tests/unit/compiler/venom/test_venom_to_assembly.py +++ b/tests/unit/compiler/venom/test_venom_to_assembly.py @@ -39,6 +39,34 @@ def test_optimistic_swap_params(): assert asm == ["SWAP2", "PUSH1", 117, "POP", "MSTORE", "MSTORE", "JUMP"] +def test_invoke_middle_output_unused(): + code = """ + function main { + main: + %a, %b, %c = invoke @callee + return %a, %c + } + + function callee { + callee: + %retpc = param + %x = 1 + %y = 2 + %z = 3 + ret %x, %y, %z, %retpc + } + """ + ctx = parse_venom(code) + asm = VenomCompiler(ctx).generate_evm_assembly() + + assert "POP" in asm, f"expected POP to remove dead output, got {asm}" + pop_idx = asm.index("POP") + assert pop_idx > 0 and asm[pop_idx - 1] == "SWAP1", asm + assert "RETURN" in asm, asm + return_idx = asm.index("RETURN") + assert return_idx > pop_idx and asm[return_idx - 1] == "SWAP1", asm + + def test_popmany_bulk_removal_of_suffix(): compiler = VenomCompiler(IRContext()) stack = StackModel() diff --git a/tests/venom_utils.py b/tests/venom_utils.py index 5c3dd83d60..0213ec82ea 
100644 --- a/tests/venom_utils.py +++ b/tests/venom_utils.py @@ -15,7 +15,11 @@ def parse_from_basic_block(source: str, funcname="_global"): def instructions_eq(i1: IRInstruction, i2: IRInstruction) -> bool: - return i1.output == i2.output and i1.opcode == i2.opcode and i1.operands == i2.operands + return ( + i1.get_outputs() == i2.get_outputs() + and i1.opcode == i2.opcode + and i1.operands == i2.operands + ) def assert_bb_eq(bb1: IRBasicBlock, bb2: IRBasicBlock): diff --git a/vyper/venom/__init__.py b/vyper/venom/__init__.py index 1b166bda77..62c63653b1 100644 --- a/vyper/venom/__init__.py +++ b/vyper/venom/__init__.py @@ -11,6 +11,7 @@ from vyper.venom.analysis import MemSSA from vyper.venom.analysis.analysis import IRAnalysesCache from vyper.venom.basicblock import IRLabel, IRLiteral +from vyper.venom.check_venom import check_calling_convention from vyper.venom.context import IRContext from vyper.venom.function import IRFunction from vyper.venom.ir_node_to_venom import ir_node_to_venom @@ -121,6 +122,8 @@ def _run_global_passes(ctx: IRContext, optimize: OptimizationLevel, ir_analyses: def run_passes_on(ctx: IRContext, optimize: OptimizationLevel) -> None: ir_analyses = {} + # Validate calling convention invariants before running passes + check_calling_convention(ctx) for fn in ctx.functions.values(): ir_analyses[fn] = IRAnalysesCache(fn) diff --git a/vyper/venom/analysis/available_expression.py b/vyper/venom/analysis/available_expression.py index 7cc6390797..72c73969cf 100644 --- a/vyper/venom/analysis/available_expression.py +++ b/vyper/venom/analysis/available_expression.py @@ -288,6 +288,8 @@ def _handle_bb(self, bb: IRBasicBlock) -> bool: for inst in bb.instructions: if inst.opcode == "assign" or inst.is_pseudo or inst.is_bb_terminator: continue + if inst.num_outputs > 1: + continue if ( inst not in self.inst_to_available @@ -341,6 +343,9 @@ def _get_operand( # source is a magic opcode for tests if inst.opcode == "source": return op + # instructions with multiple outputs currently can't be mapped an expression + if inst.num_outputs > 1: + return op assert inst in self.inst_to_expr, f"operand source was not handled, ({op}, {inst})" return self.inst_to_expr[inst] diff --git a/vyper/venom/analysis/dfg.py b/vyper/venom/analysis/dfg.py index 4a3ce0aaff..1fa747e926 100644 --- a/vyper/venom/analysis/dfg.py +++ b/vyper/venom/analysis/dfg.py @@ -84,8 +84,7 @@ def analyze(self): inputs = self._dfg_inputs.setdefault(op, OrderedSet()) inputs.add(inst) - for op in res: # type: ignore - assert isinstance(op, IRVariable) + for op in res: self._dfg_outputs[op] = inst def as_graph(self) -> str: @@ -96,8 +95,7 @@ def as_graph(self) -> str: for var, inputs in self._dfg_inputs.items(): for input in inputs: for op in input.get_outputs(): - if isinstance(op, IRVariable): - lines.append(f' " {var.name} " -> " {op.name} "') + lines.append(f' " {var.name} " -> " {op.name} "') lines.append("}") return "\n".join(lines) diff --git a/vyper/venom/analysis/stack_order.py b/vyper/venom/analysis/stack_order.py index 01c671ba72..d7fd4a0378 100644 --- a/vyper/venom/analysis/stack_order.py +++ b/vyper/venom/analysis/stack_order.py @@ -71,8 +71,7 @@ def analyze_bb(self, bb: IRBasicBlock) -> Needed: inst.operands, ) self.stack = self.stack[: -len(inst.operands)] - if inst.output is not None: - self.stack.append(inst.output) + self.stack.extend(inst.get_outputs()) for pred in self.cfg.cfg_in(bb): self._from_to[(pred, bb)] = self.needed.copy() @@ -97,7 +96,7 @@ def from_to(self, origin: IRBasicBlock, successor: 
IRBasicBlock) -> Needed: def _handle_assign(self, inst: IRInstruction): assert inst.opcode == "assign" - assert inst.output is not None + _ = inst.output # Assert single output index = inst.parent.instructions.index(inst) next_inst = inst.parent.instructions[index + 1] diff --git a/vyper/venom/analysis/var_definition.py b/vyper/venom/analysis/var_definition.py index c8f8c46c7c..9abe3265b2 100644 --- a/vyper/venom/analysis/var_definition.py +++ b/vyper/venom/analysis/var_definition.py @@ -51,8 +51,8 @@ def _handle_bb(self, bb: IRBasicBlock) -> bool: for inst in bb.instructions: self.defined_vars[inst] = bb_defined.copy() - if inst.output is not None: - bb_defined.add(inst.output) + outs = inst.get_outputs() + bb_defined.addmany(outs) if self.defined_vars_bb[bb] != bb_defined: self.defined_vars_bb[bb] = bb_defined diff --git a/vyper/venom/basicblock.py b/vyper/venom/basicblock.py index 1b72094dd7..a25c669c46 100644 --- a/vyper/venom/basicblock.py +++ b/vyper/venom/basicblock.py @@ -229,7 +229,7 @@ class IRInstruction: opcode: str operands: list[IROperand] - output: Optional[IRVariable] + _outputs: list[IRVariable] parent: IRBasicBlock annotation: Optional[str] ast_source: Optional[IRnode] @@ -239,13 +239,14 @@ def __init__( self, opcode: str, operands: list[IROperand] | Iterator[IROperand], - output: Optional[IRVariable] = None, + outputs: Optional[list[IRVariable]] = None, ): assert isinstance(opcode, str), "opcode must be an str" assert isinstance(operands, list | Iterator), "operands must be a list" self.opcode = opcode self.operands = list(operands) # in case we get an iterator - self.output = output + self._outputs = list(outputs) if outputs is not None else [] + self.annotation = None self.ast_source = None self.error_msg = None @@ -311,18 +312,45 @@ def get_input_variables(self) -> Iterator[IRVariable]: """ return (op for op in self.operands if isinstance(op, IRVariable)) - def get_outputs(self) -> list[IROperand]: + def get_outputs(self) -> list[IRVariable]: + """ + Get the outputs of the instruction. + Makes a copy to prevent external mutation, so + keep that in mind when performance matters. + """ + return list(self._outputs) + + @property + def num_outputs(self) -> int: + """ + Return how many outputs this instruction produces. + """ + return len(self._outputs) + + @property + def output(self) -> IRVariable: + """ + Return the single output for instructions with exactly one. + """ + assert len(self._outputs) == 1, f"expected single output for {self}" + return self._outputs[0] + + def has_outputs(self) -> bool: """ - Get the output item for an instruction. - (Currently all instructions output at most one item, but write - it as a list to be generic for the future) + Check whether this instruction produces any outputs. """ - return [self.output] if self.output else [] + return len(self._outputs) > 0 + + def set_outputs(self, outputs: list[IRVariable]) -> None: + """ + Replace all outputs for this instruction. 
+ """ + self._outputs = list(outputs) def make_nop(self): self.annotation = str(self) # Keep original instruction as annotation for debugging self.opcode = "nop" - self.output = None + self._outputs = [] self.operands = [] def flip(self): @@ -398,7 +426,7 @@ def get_ast_source(self) -> Optional[IRnode]: return self.parent.parent.ast_source def copy(self) -> IRInstruction: - ret = IRInstruction(self.opcode, self.operands.copy(), self.output) + ret = IRInstruction(self.opcode, self.operands.copy(), self.get_outputs()) ret.annotation = self.annotation ret.ast_source = self.ast_source ret.error_msg = self.error_msg @@ -406,8 +434,9 @@ def copy(self) -> IRInstruction: def str_short(self) -> str: s = "" - if self.output: - s += f"{self.output} = " + outs = self.get_outputs() + if len(outs) > 0: + s += f"{', '.join(map(str, outs))} = " opcode = f"{self.opcode} " if self.opcode != "assign" else "" s += opcode operands = self.operands @@ -418,8 +447,9 @@ def str_short(self) -> str: def __repr__(self) -> str: s = "" - if self.output: - s += f"{self.output} = " + outs = self.get_outputs() + if len(outs) > 0: + s += f"{', '.join(map(str, outs))} = " opcode = f"{self.opcode} " if self.opcode != "assign" else "" s += opcode operands = self.operands @@ -506,13 +536,17 @@ def append_instruction( """ assert not self.is_terminated, self - if ret is None: - ret = self.parent.get_next_variable() if opcode not in NO_OUTPUT_INSTRUCTIONS else None + if ret is None and opcode not in NO_OUTPUT_INSTRUCTIONS: + ret = self.parent.get_next_variable() # Wrap raw integers in IRLiterals inst_args = [_ir_operand_from_value(arg) for arg in args] - inst = IRInstruction(opcode, inst_args, ret) + outputs = None + if ret is not None: + outputs = [ret] + + inst = IRInstruction(opcode, inst_args, outputs) inst.parent = self inst.ast_source = self.parent.ast_source inst.error_msg = self.parent.error_msg @@ -521,27 +555,27 @@ def append_instruction( return ret def append_invoke_instruction( - self, args: list[IROperand | int], returns: bool - ) -> Optional[IRVariable]: + self, args: list[IROperand | int], returns: int = 0 + ) -> list[IRVariable]: """ - Append an invoke to the basic block + Append an invoke to the basic block. Always returns a list of output variables. """ assert not self.is_terminated, self - ret = None - if returns: - ret = self.parent.get_next_variable() + + # Determine outputs + outputs: list[IRVariable] = [self.parent.get_next_variable() for _ in range(returns)] # Wrap raw integers in IRLiterals inst_args = [_ir_operand_from_value(arg) for arg in args] assert isinstance(inst_args[0], IRLabel), "Invoked non label" - inst = IRInstruction("invoke", inst_args, ret) + inst = IRInstruction("invoke", inst_args, outputs) inst.parent = self inst.ast_source = self.parent.ast_source inst.error_msg = self.parent.error_msg self.instructions.append(inst) - return ret + return outputs def insert_instruction(self, instruction: IRInstruction, index: Optional[int] = None) -> None: assert isinstance(instruction, IRInstruction), "instruction must be an IRInstruction" @@ -626,7 +660,11 @@ def get_assignments(self): """ Get all assignments in basic block. 
""" - return [inst.output for inst in self.instructions if inst.output] + ret: list[IRVariable] = [] + for inst in self.instructions: + outs = inst.get_outputs() + ret.extend(outs) + return ret def get_uses(self) -> dict[IRVariable, OrderedSet[IRInstruction]]: uses: dict[IRVariable, OrderedSet[IRInstruction]] = {} diff --git a/vyper/venom/check_venom.py b/vyper/venom/check_venom.py index 0dc198d2b3..7f85924245 100644 --- a/vyper/venom/check_venom.py +++ b/vyper/venom/check_venom.py @@ -1,5 +1,5 @@ from vyper.venom.analysis import IRAnalysesCache, VarDefinition -from vyper.venom.basicblock import IRBasicBlock, IRVariable +from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRLabel, IRVariable from vyper.venom.context import IRContext from vyper.venom.function import IRFunction @@ -30,6 +30,49 @@ def __str__(self): return f"var {self.var} not defined:\n {self.inst}\n\n{bb}" +class InconsistentReturnArity(VenomError): + message: str = "function has inconsistent return arity" + + def __init__(self, function: IRFunction, arities: set[int]): + self.function = function + self.arities = arities + + def __str__(self): + return ( + f"function {self.function.name} has inconsistent 'ret' arities: {sorted(self.arities)}" + ) + + +class InvokeArityMismatch(VenomError): + message: str = "invoke outputs do not match callee return arity" + + def __init__(self, caller: IRFunction, inst: IRInstruction, expected: int, got: int): + self.caller = caller + self.inst = inst + self.expected = expected + self.got = got + + def __str__(self): + bb = self.inst.parent + return ( + f"invoke arity mismatch in {self.caller.name}: " + f"expected {self.expected}, got {self.got}\n" + f" {self.inst}\n\n{bb}" + ) + + +class MultiOutputNonInvoke(VenomError): + message: str = "multi-output assignment only supported for invoke" + + def __init__(self, caller: IRFunction, inst: IRInstruction): + self.caller = caller + self.inst = inst + + def __str__(self): + bb = self.inst.parent + return f"multi-output on non-invoke in {self.caller.name}:\n" f" {self.inst}\n\n{bb}" + + def _handle_var_definition( fn: IRFunction, bb: IRBasicBlock, var_def: VarDefinition ) -> list[VenomError]: @@ -68,12 +111,71 @@ def find_semantic_errors_fn(fn: IRFunction) -> list[VenomError]: return errors +def _collect_ret_arities(context: IRContext) -> dict[IRFunction, set[int]]: + ret_arities: dict[IRFunction, set[int]] = {} + for fn in context.functions.values(): + arities: set[int] = set() + for bb in fn.get_basic_blocks(): + for inst in bb.instructions: + if inst.opcode == "ret": + # last operand is return PC; all preceding (if any) are return values + arities.add(len(inst.operands) - 1) + + ret_arities[fn] = arities + + return ret_arities + + +def find_calling_convention_errors(context: IRContext) -> list[VenomError]: + errors: list[VenomError] = [] + + # Enforce invoke binding exactly callee arity + ret_arities = _collect_ret_arities(context) + + for fn, arities in ret_arities.items(): + if len(arities) > 1: + errors.append(InconsistentReturnArity(fn, arities)) + + for caller in context.functions.values(): + for bb in caller.get_basic_blocks(): + for inst in bb.instructions: + # Disallow multi-output except on invoke + got_num = inst.num_outputs + if got_num > 1 and inst.opcode != "invoke": + errors.append(MultiOutputNonInvoke(caller, inst)) + continue + if inst.opcode != "invoke": + continue + target = inst.operands[0] + assert isinstance(target, IRLabel) + callee = context.get_function(target) + arities = ret_arities[callee] + + if 
len(arities) == 0: + expected_num = 0 + elif len(arities) == 1: + expected_num = next(iter(arities)) + else: + # a function with InconsistentReturnArity, we already + # checked this above + continue + + if got_num != expected_num: + errors.append(InvokeArityMismatch(caller, inst, expected_num, got_num)) + + return errors + + def find_semantic_errors(context: IRContext) -> list[VenomError]: errors: list[VenomError] = [] + # Per-function basic checks (var definitions, bb termination, etc.) for fn in context.functions.values(): errors.extend(find_semantic_errors_fn(fn)) + # Calling convention errors can be reported too if desired + errors.extend(find_calling_convention_errors(context)) + return errors @@ -82,3 +184,9 @@ def check_venom_ctx(context: IRContext): if errors: raise ExceptionGroup("venom semantic errors", errors) + + +def check_calling_convention(context: IRContext): + errors = find_calling_convention_errors(context) + if errors: + raise ExceptionGroup("venom calling convention errors", errors) diff --git a/vyper/venom/function.py b/vyper/venom/function.py index c57229eabc..59cdaa6b05 100644 --- a/vyper/venom/function.py +++ b/vyper/venom/function.py @@ -118,8 +118,9 @@ def freshen_varnames(self) -> None: varmap: dict[IRVariable, IRVariable] = defaultdict(self.get_next_variable) for bb in self.get_basic_blocks(): for inst in bb.instructions: - if inst.output: - inst.output = varmap[inst.output] + all_outputs = inst.get_outputs() + if len(all_outputs) > 0: + inst.set_outputs([varmap[o] for o in all_outputs]) for i, op in enumerate(inst.operands): if not isinstance(op, IRVariable): diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py index 3e3842540b..a1beb6c2ef 100644 --- a/vyper/venom/ir_node_to_venom.py +++ b/vyper/venom/ir_node_to_venom.py @@ -6,6 +6,7 @@ from typing import Optional from vyper.codegen.context import Alloca +from vyper.codegen.core import is_tuple_like from vyper.codegen.ir_node import IRnode from vyper.evm.opcodes import get_opcodes from vyper.ir.compile_ir import _runtime_code_offsets @@ -20,7 +21,9 @@ from vyper.venom.context import IRContext from vyper.venom.function import IRFunction, IRParameter -ENABLE_NEW_CALL_CONV = True +# Experimental: allow returning multiple 32-byte values via the stack +ENABLE_MULTI_RETURNS = True +MAX_STACK_RETURNS = 2 MAX_STACK_ARGS = 6 # Instructions that are mapped to their inverse @@ -180,9 +183,6 @@ def _append_return_args(fn: IRFunction, ofst: int = 0, size: int = 0): def _pass_via_stack(func_t) -> dict[str, bool]: # returns a dict which returns True if a given argument (referred to # by name) should be passed via the stack - if not ENABLE_NEW_CALL_CONV: - return {arg.name: False for arg in func_t.arguments} - arguments = {arg.name: arg for arg in func_t.arguments} stack_items = 0 @@ -211,7 +211,7 @@ def _handle_self_call(fn: IRFunction, ir: IRnode, symbols: SymbolTable) -> Optio func_t = ir.passthrough_metadata["func_t"] assert func_t is not None, "func_t not found in passthrough metadata" - returns_word = _returns_word(func_t) + returns_count = _returns_stack_count(func_t) if setup_ir != goto_ir: _convert_ir_bb(fn, setup_ir, symbols) @@ -223,35 +223,46 @@ def _handle_self_call(fn: IRFunction, ir: IRnode, symbols: SymbolTable) -> Optio callsite = callsite_op.value bb = fn.get_basic_block() - return_buf = None - + return_buf: Optional[IROperand] = None + # If a return buffer pointer is supplied by upstream IR, use it if len(converted_args) > 1: return_buf = converted_args[0] + # For multi-return via 
stack without a provided buffer, synthesize one + if returns_count > 0 and return_buf is None: + tmp_buf = bb.append_instruction("alloca", 0, 32 * returns_count, get_scratch_alloca_id()) + assert tmp_buf is not None + return_buf = tmp_buf stack_args: list[IROperand] = [IRLabel(str(target_label))] if return_buf is not None: - if not ENABLE_NEW_CALL_CONV or not returns_word: + if returns_count == 0: stack_args.append(return_buf) # type: ignore callsite_args = _callsites[callsite] - if ENABLE_NEW_CALL_CONV: - for alloca in callsite_args: - if not _pass_via_stack(func_t)[alloca.name]: - continue - ptr = _alloca_table[alloca._id] - stack_arg = bb.append_instruction("mload", ptr) - assert stack_arg is not None - stack_args.append(stack_arg) - - if returns_word: - ret_value = bb.append_invoke_instruction(stack_args, returns=True) # type: ignore - assert ret_value is not None - assert isinstance(return_buf, IROperand) - bb.append_instruction("mstore", ret_value, return_buf) - return return_buf - - bb.append_invoke_instruction(stack_args, returns=False) # type: ignore + for alloca in callsite_args: + if not _pass_via_stack(func_t)[alloca.name]: + continue + ptr = _alloca_table[alloca._id] + stack_arg = bb.append_instruction("mload", ptr) + assert stack_arg is not None + stack_args.append(stack_arg) + + if returns_count > 0: + outs = bb.append_invoke_instruction(stack_args, returns=returns_count) # type: ignore + assert isinstance(return_buf, IROperand) + for i, outv in enumerate(outs): + if i == 0: + dst = return_buf + else: + ofst = bb.append_instruction("assign", IRLiteral(32 * i)) + assert ofst is not None + dst = bb.append_instruction("add", return_buf, ofst) # type: ignore + assert dst is not None + bb.append_instruction("mstore", outv, dst) + return return_buf + + bb.append_invoke_instruction(stack_args, returns=0) # type: ignore return return_buf @@ -270,6 +281,18 @@ def _returns_word(func_t) -> bool: return return_t is not None and _is_word_type(return_t) +def _returns_stack_count(func_t) -> int: + ret_t = func_t.return_type + if ret_t is None: + return 0 + if ENABLE_MULTI_RETURNS and is_tuple_like(ret_t): + members = ret_t.tuple_items() + if 1 <= len(members) <= MAX_STACK_RETURNS and all(_is_word_type(t) for (_k, t) in members): + return len(members) + return 0 + return 1 if _is_word_type(ret_t) else 0 + + def _handle_internal_func( # TODO: remove does_return_data, replace with `func_t.return_type is not None` fn: IRFunction, @@ -295,17 +318,18 @@ def _handle_internal_func( _saved_alloca_table = _alloca_table _alloca_table = {} - returns_word = _returns_word(func_t) + returns_count = _returns_stack_count(func_t) # return buffer if does_return_data: - if ENABLE_NEW_CALL_CONV and returns_word: + if returns_count > 0: # TODO: remove this once we have proper memory allocator # functionality in venom. Currently, we hardcode the scratch - # buffer size of 32 bytes. + # buffer size of up to 32 * MAX_STACK_RETURNS (2) bytes. # TODO: we don't need to use scratch space once the legacy optimizer # is disabled. 
- buf = bb.append_instruction("alloca", 0, 32, get_scratch_alloca_id()) + # allocate scratch return buffer sized to the number of stack-returned words + buf = bb.append_instruction("alloca", 0, 32 * returns_count, get_scratch_alloca_id()) else: buf = bb.append_instruction("param") bb.instructions[-1].annotation = "return_buffer" @@ -313,32 +337,31 @@ def _handle_internal_func( assert buf is not None # help mypy symbols["return_buffer"] = buf - if ENABLE_NEW_CALL_CONV: - stack_index = 0 - if func_t.return_type is not None and not _returns_word(func_t): - stack_index += 1 - for arg in func_t.arguments: - if not _pass_via_stack(func_t)[arg.name]: - continue - - param = bb.append_instruction("param") - bb.instructions[-1].annotation = arg.name - assert param is not None # help mypy - - var = context.lookup_var(arg.name) - - venom_arg = IRParameter( - name=var.name, - index=stack_index, - offset=var.alloca.offset, - size=var.alloca.size, - id_=var.alloca._id, - call_site_var=None, - func_var=param, - addr_var=None, - ) - fn.args.append(venom_arg) - stack_index += 1 + stack_index = 0 + if func_t.return_type is not None and _returns_stack_count(func_t) == 0: + stack_index += 1 + for arg in func_t.arguments: + if not _pass_via_stack(func_t)[arg.name]: + continue + + param = bb.append_instruction("param") + bb.instructions[-1].annotation = arg.name + assert param is not None # help mypy + + var = context.lookup_var(arg.name) + + venom_arg = IRParameter( + name=var.name, + index=stack_index, + offset=var.alloca.offset, + size=var.alloca.size, + id_=var.alloca._id, + call_site_var=None, + func_var=param, + addr_var=None, + ) + fn.args.append(venom_arg) + stack_index += 1 # return address return_pc = bb.append_instruction("param") @@ -573,10 +596,19 @@ def _convert_ir_bb(fn, ir, symbols): if label.value == "return_pc": label = symbols.get("return_pc") # return label should be top of stack - if _returns_word(_current_func_t) and ENABLE_NEW_CALL_CONV: + k = _returns_stack_count(_current_func_t) + if k > 0: buf = symbols["return_buffer"] - val = bb.append_instruction("mload", buf) - bb.append_instruction("ret", val, label) + ret_vals: list[IROperand] = [] + for i in range(k): + if i == 0: + ptr = buf + else: + ofst = bb.append_instruction("assign", IRLiteral(32 * i)) + ptr = bb.append_instruction("add", buf, ofst) + val = bb.append_instruction("mload", ptr) + ret_vals.append(val) # type: ignore[arg-type] + bb.append_instruction("ret", *ret_vals, label) else: bb.append_instruction("ret", label) @@ -662,7 +694,7 @@ def emit_body_blocks(): fn.append_basic_block(incr_block) incr_block.insert_instruction( - IRInstruction("add", [counter_var, IRLiteral(1)], counter_var) + IRInstruction("add", [counter_var, IRLiteral(1)], [counter_var]) ) incr_block.append_instruction("jmp", cond_block.label) @@ -702,7 +734,7 @@ def emit_body_blocks(): bb = fn.get_basic_block() ptr = bb.append_instruction("palloca", alloca.offset, alloca.size, alloca._id) bb.instructions[-1].annotation = f"{alloca.name} (memory)" - if ENABLE_NEW_CALL_CONV and _pass_via_stack(_current_func_t)[alloca.name]: + if _pass_via_stack(_current_func_t)[alloca.name]: param = fn.get_param_by_id(alloca._id) assert param is not None bb.append_instruction("mstore", param.func_var, ptr) @@ -716,7 +748,7 @@ def emit_body_blocks(): bb = fn.get_basic_block() callsite_func = ir.passthrough_metadata["callsite_func"] - if ENABLE_NEW_CALL_CONV and _pass_via_stack(callsite_func)[alloca.name]: + if _pass_via_stack(callsite_func)[alloca.name]: ptr = 
bb.append_instruction("alloca", alloca.offset, alloca.size, alloca._id) else: # if we use alloca, mstores might get removed. convert diff --git a/vyper/venom/parser.py b/vyper/venom/parser.py index 55d27622c8..c0e198d314 100644 --- a/vyper/venom/parser.py +++ b/vyper/venom/parser.py @@ -36,7 +36,9 @@ label_decl: (IDENT | ESCAPED_STRING) ":" NEWLINE+ statement: (assignment | instruction) NEWLINE+ - assignment: VAR_IDENT "=" expr + assignment: lhs "=" expr + lhs: VAR_IDENT | lhs_list + lhs_list: VAR_IDENT ("," VAR_IDENT)+ expr: instruction | operand instruction: IDENT operands_list? @@ -71,13 +73,12 @@ def _set_last_var(fn: IRFunction): for bb in fn.get_basic_blocks(): for inst in bb.instructions: - if inst.output is None: - continue - value = inst.output.value - assert value.startswith("%") - varname = value[1:] - if varname.isdigit(): - fn.last_variable = max(fn.last_variable, int(varname)) + for output in inst.get_outputs(): + value = output.value + assert value.startswith("%") + varname = value[1:] + if varname.isdigit(): + fn.last_variable = max(fn.last_variable, int(varname)) def _set_last_label(ctx: IRContext): @@ -206,13 +207,32 @@ def data_item(self, children) -> DataItem: item = item.replace("_", "") return DataItem(bytes.fromhex(item)) + def lhs(self, children): + # unwrap VAR_IDENT or lhs_list + assert len(children) == 1 + return children[0] + + def lhs_list(self, children): + # list of VAR_IDENTs + return children + def assignment(self, children) -> IRInstruction: - to, value = children + left, value = children + # Multi-output assignment (e.g., %a, %b = invoke @f) + if isinstance(left, list): + if not isinstance(value, IRInstruction): + raise TypeError("Multi-target assignment requires an instruction on RHS") + outs = left + value.set_outputs(outs) + return value + + # Single-target assignment + to = left if isinstance(value, IRInstruction): - value.output = to + value.set_outputs([to]) return value if isinstance(value, (IRLiteral, IRVariable, IRLabel)): - return IRInstruction("assign", [value], output=to) + return IRInstruction("assign", [value], [to]) raise TypeError(f"Unexpected value {value} of type {type(value)}") def expr(self, children) -> IRInstruction | IROperand: diff --git a/vyper/venom/passes/algebraic_optimization.py b/vyper/venom/passes/algebraic_optimization.py index e04dc02fc8..a4a7cd6be7 100644 --- a/vyper/venom/passes/algebraic_optimization.py +++ b/vyper/venom/passes/algebraic_optimization.py @@ -55,8 +55,8 @@ def _optimize_iszero_chains(self) -> None: if iszero_count == 0: continue - assert isinstance(inst.output, IRVariable) - for use_inst in self.dfg.get_uses(inst.output).copy(): + inst_out = inst.output + for use_inst in self.dfg.get_uses(inst_out).copy(): opcode = use_inst.opcode if opcode == "iszero": @@ -75,7 +75,7 @@ def _optimize_iszero_chains(self) -> None: continue out_var = iszero_chain[keep_count].operands[0] - self.updater.update_operands(use_inst, {inst.output: out_var}) + self.updater.update_operands(use_inst, {inst_out: out_var}) def _get_iszero_chain(self, op: IROperand) -> list[IRInstruction]: chain: list[IRInstruction] = [] @@ -123,8 +123,9 @@ def _flip_inst(self, inst: IRInstruction): # "peephole", weakening algebraic optimizations def _handle_inst_peephole(self, inst: IRInstruction): - if inst.output is None: + if inst.num_outputs != 1: return + inst_out = inst.output if inst.is_volatile: return if inst.opcode == "assign": @@ -243,8 +244,7 @@ def _handle_inst_peephole(self, inst: IRInstruction): return return - assert inst.output is not 
None - uses = self.dfg.get_uses(inst.output) + uses = self.dfg.get_uses(inst_out) is_truthy = all(i.opcode in TRUTHY_INSTRUCTIONS for i in uses) prefer_iszero = all(i.opcode in ("assert", "iszero") for i in uses) @@ -302,7 +302,7 @@ def _handle_inst_peephole(self, inst: IRInstruction): def _optimize_comparator_instruction(self, inst, prefer_iszero): opcode, operands = inst.opcode, inst.operands assert opcode in COMPARATOR_INSTRUCTIONS # sanity - assert isinstance(inst.output, IRVariable) # help mypy + inst_out = inst.output # (x > x) == (x < x) -> 0 if operands[0] == operands[1]: @@ -359,8 +359,7 @@ def _optimize_comparator_instruction(self, inst, prefer_iszero): # rewrite comparisons by either inserting or removing an `iszero`, # e.g. `x > N` -> `x >= (N + 1)` - assert inst.output is not None - uses = self.dfg.get_uses(inst.output) + uses = self.dfg.get_uses(inst_out) if len(uses) != 1: return @@ -401,9 +400,8 @@ def _optimize_comparator_instruction(self, inst, prefer_iszero): if insert_iszero: # next instruction is an assert, so we insert an iszero so # that there will be two iszeros in the assembly. - assert inst.output is not None, inst assert len(after.operands) == 1, after - var = self.updater.add_before(after, "iszero", [inst.output]) + var = self.updater.add_before(after, "iszero", [inst_out]) self.updater.update_operands(after, {after.operands[0]: var}) else: # remove the iszero! diff --git a/vyper/venom/passes/cfg_normalization.py b/vyper/venom/passes/cfg_normalization.py index a801ae91dc..25b9f42444 100644 --- a/vyper/venom/passes/cfg_normalization.py +++ b/vyper/venom/passes/cfg_normalization.py @@ -73,7 +73,7 @@ def _insert_split_basicblock(self, bb: IRBasicBlock, pred_bb: IRBasicBlock) -> I def _needs_forwarding_store(self, var: IRVariable, pred_bb: IRBasicBlock) -> bool: for inst in pred_bb.instructions: - if inst.output == var: + if var in inst.get_outputs(): # variable defined by phi needs forwarding return inst.opcode == "phi" # variable not defined in predecessor needs forwarding diff --git a/vyper/venom/passes/common_subexpression_elimination.py b/vyper/venom/passes/common_subexpression_elimination.py index 9735832d8f..7ead122b80 100644 --- a/vyper/venom/passes/common_subexpression_elimination.py +++ b/vyper/venom/passes/common_subexpression_elimination.py @@ -4,7 +4,7 @@ ) from vyper.venom.analysis.dfg import DFGAnalysis from vyper.venom.analysis.liveness import LivenessAnalysis -from vyper.venom.basicblock import IRInstruction, IRVariable +from vyper.venom.basicblock import IRInstruction from vyper.venom.passes.base_pass import IRPass # instruction that dont need to be stored in available expression @@ -42,6 +42,15 @@ class CSE(IRPass): + """ + Common Subexpression Elimination pass. + + NOTE: This pass does not support instructions with multiple outputs. Currently, + only `invoke` instructions have multiple outputs, and they are excluded from + substitution. But eventually, this pass should be extended to support multi-output + to support the folding of pure functions that return multiple values, etc. 
+ """ + expression_analysis: AvailableExpressionAnalysis def run_pass(self): @@ -72,6 +81,9 @@ def _find_replaceble(self) -> dict[IRInstruction, IRInstruction]: continue if inst.opcode in NONIDEMPOTENT_INSTRUCTIONS: continue + # skip multi-output instructions for now (not supported yet) + if inst.num_outputs > 1: + continue state = self.expression_analysis.get_expression(inst) if state is None: continue @@ -95,9 +107,8 @@ def _replace(self, replace_dict: dict[IRInstruction, IRInstruction]): self._replace_inst(orig, to) def _replace_inst(self, orig_inst: IRInstruction, to_inst: IRInstruction): - if orig_inst.output is not None: + if orig_inst.has_outputs(): orig_inst.opcode = "assign" - assert isinstance(to_inst.output, IRVariable), f"not var {to_inst}" orig_inst.operands = [to_inst.output] else: orig_inst.opcode = "nop" diff --git a/vyper/venom/passes/dead_store_elimination.py b/vyper/venom/passes/dead_store_elimination.py index f00eb792fe..15ec5104f6 100644 --- a/vyper/venom/passes/dead_store_elimination.py +++ b/vyper/venom/passes/dead_store_elimination.py @@ -37,7 +37,7 @@ def _has_uses(self, inst: IRInstruction): """ Checks if the instruction's output is used in the DFG. """ - return inst.output is not None and len(self.dfg.get_uses(inst.output)) > 0 + return any(len(self.dfg.get_uses(output)) > 0 for output in inst.get_outputs()) def _is_memory_def_live(self, query_def: MemoryDef) -> bool: """ diff --git a/vyper/venom/passes/dft.py b/vyper/venom/passes/dft.py index 175d6010cf..8d1b37fe70 100644 --- a/vyper/venom/passes/dft.py +++ b/vyper/venom/passes/dft.py @@ -88,18 +88,37 @@ def cost(x: IRInstruction) -> int | float: # indirect data dependencies (offspring of operands) # direct data dependencies (order of operands) - if (x not in self.dda[inst] and x in self.eda[inst]) or inst.flippable: - ret = -1 * int(len(self.data_offspring[x]) > 0) - elif x.output in inst.operands: - assert x in self.dda[inst] # sanity check - assert x.output is not None # help mypy - ret = inst.operands.index(x.output) + len(self.order) - else: - assert x in self.dda[inst] # sanity check - assert x.output in self.order - assert x.output is not None # help mypy - ret = self.order.index(x.output) - return ret + is_effect_only = x not in self.dda[inst] and x in self.eda[inst] + if is_effect_only or inst.flippable: + has_data_offspring = len(self.data_offspring[x]) > 0 + return -1 if has_data_offspring else 0 + + assert x in self.dda[inst] # sanity check + + # locate operands that are produced by x and prefer earliest match + operand_idxs = [ + i + for i, op in enumerate(inst.operands) + if self.dfg.get_producing_instruction(op) is x + ] + if len(operand_idxs) > 0: + return min(operand_idxs) + len(self.order) + + outputs = x.get_outputs() + operand_positions = [ + inst.operands.index(out_var) for out_var in outputs if out_var in inst.operands + ] + if len(operand_positions) > 0: + return min(operand_positions) + len(self.order) + + order_positions = [ + self.order.index(out_var) for out_var in outputs if out_var in self.order + ] + if len(order_positions) > 0: + return min(order_positions) + + # fall back to a stable default when no operand is associated + return len(self.order) # heuristic: sort by size of child dependency graph orig_children = children.copy() diff --git a/vyper/venom/passes/function_inliner.py b/vyper/venom/passes/function_inliner.py index 6d6f355d2b..eabcc67b0c 100644 --- a/vyper/venom/passes/function_inliner.py +++ b/vyper/venom/passes/function_inliner.py @@ -7,7 +7,6 @@ from 
vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRLabel, IROperand, IRVariable from vyper.venom.context import IRContext from vyper.venom.function import IRFunction -from vyper.venom.ir_node_to_venom import ENABLE_NEW_CALL_CONV from vyper.venom.passes import FloatAllocas from vyper.venom.passes.base_pass import IRGlobalPass @@ -115,7 +114,6 @@ def _inline_function(self, func: IRFunction, call_sites: List[IRInstruction]) -> # this can happen when we have a->b->c and a->c, # and both b and c get inlined. calloca_inst = callocas[alloca_id] - assert calloca_inst.output is not None inst.opcode = "assign" inst.operands = [calloca_inst.output] else: @@ -131,7 +129,6 @@ def _inline_function(self, func: IRFunction, call_sites: List[IRInstruction]) -> continue inst.opcode = "assign" calloca_inst = callocas[alloca_id] - assert calloca_inst.output is not None # help mypy inst.operands = [calloca_inst.output] found.add(alloca_id) @@ -144,7 +141,7 @@ def _inline_function(self, func: IRFunction, call_sites: List[IRInstruction]) -> # demote to alloca so that mem2var will work inst.opcode = "alloca" - def _inline_call_site(self, func: IRFunction, call_site: IRInstruction): + def _inline_call_site(self, func: IRFunction, call_site: IRInstruction) -> None: """ Inline function into call site. """ @@ -190,16 +187,28 @@ def _inline_call_site(self, func: IRFunction, call_site: IRInstruction): # will be handled at the toplevel `inline_function` pass elif inst.opcode == "ret": - if len(inst.operands) > 1: - # sanity check (should remove once new callconv stabilizes) - assert ENABLE_NEW_CALL_CONV - ret_value = inst.operands[0] - bb.insert_instruction( - IRInstruction("assign", [ret_value], call_site.output), -1 - ) + # ret may be: ret @return_pc OR ret v1, v2, ..., @return_pc + # The last operand is the return PC (label or variable); + # all preceding operands (if any) are return values. 
+ ret_values = [op for op in inst.operands[:-1] if not isinstance(op, IRLabel)] + inst.opcode = "jmp" inst.operands = [call_site_return.label] + # Map each returned value to corresponding callsite outputs + if len(ret_values) == 0: + continue + + callsite_outs = call_site.get_outputs() + assert len(ret_values) == len( + callsite_outs + ), f"Return arity mismatch: {len(ret_values)} vs {len(callsite_outs)}" + for idx, ret_value in enumerate(ret_values): + target_out = callsite_outs[idx] + bb.insert_instruction( + IRInstruction("assign", [ret_value], [target_out]), -1 + ) + for inst in bb.instructions: if not inst.annotation: inst.annotation = f"from {func.name}" @@ -280,11 +289,10 @@ def _clone_instruction(self, inst: IRInstruction, prefix: str) -> IRInstruction: else: ops.append(op) - output = None - if inst.output: - output = IRVariable(f"{prefix}{inst.output.plain_name}") + all_outputs = inst.get_outputs() + cloned_outputs = [IRVariable(f"{prefix}{o.plain_name}") for o in all_outputs] - clone = IRInstruction(inst.opcode, ops, output) + clone = IRInstruction(inst.opcode, ops, cloned_outputs) clone.parent = inst.parent clone.annotation = inst.annotation clone.ast_source = inst.ast_source diff --git a/vyper/venom/passes/load_elimination.py b/vyper/venom/passes/load_elimination.py index 1b051e8e03..c93a3fa50e 100644 --- a/vyper/venom/passes/load_elimination.py +++ b/vyper/venom/passes/load_elimination.py @@ -86,7 +86,6 @@ def _handle_bb( if inst.opcode == load_opcode: self.inst_to_lattice[inst] = lattice.copy() ptr = inst.operands[0] - assert inst.output is not None lattice[ptr] = OrderedSet([inst.output]) elif inst.opcode == store_opcode: self.inst_to_lattice[inst] = lattice.copy() @@ -169,8 +168,6 @@ def _handle_load(self, inst): existing_value = self._lattice[inst].get(ptr, OrderedSet()).copy() - assert inst.output is not None # help mypy - if len(existing_value) == 1: self.updater.mk_assign(inst, existing_value.pop()) elif len(existing_value) > 1: diff --git a/vyper/venom/passes/lower_dload.py b/vyper/venom/passes/lower_dload.py index c863a1b7c7..970b0f0328 100644 --- a/vyper/venom/passes/lower_dload.py +++ b/vyper/venom/passes/lower_dload.py @@ -22,7 +22,7 @@ def _handle_bb(self, bb: IRBasicBlock): (ptr,) = inst.operands var = fn.get_next_variable() bb.insert_instruction( - IRInstruction("add", [ptr, IRLabel("code_end")], output=var), index=idx + IRInstruction("add", [ptr, IRLabel("code_end")], [var]), index=idx ) idx += 1 dst = IRLiteral(MemoryPositions.FREE_VAR_SPACE) @@ -36,7 +36,7 @@ def _handle_bb(self, bb: IRBasicBlock): _, src, _ = inst.operands code_ptr = fn.get_next_variable() bb.insert_instruction( - IRInstruction("add", [src, IRLabel("code_end")], output=code_ptr), index=idx + IRInstruction("add", [src, IRLabel("code_end")], [code_ptr]), index=idx ) inst.opcode = "codecopy" inst.operands[1] = code_ptr diff --git a/vyper/venom/passes/machinery/inst_updater.py b/vyper/venom/passes/machinery/inst_updater.py index f8ab8bafb9..9701337375 100644 --- a/vyper/venom/passes/machinery/inst_updater.py +++ b/vyper/venom/passes/machinery/inst_updater.py @@ -24,7 +24,6 @@ def update_operands( # move the uses of old_var to new_inst def move_uses(self, old_var: IRVariable, new_inst: IRInstruction): - assert new_inst.output is not None new_var = new_inst.output for use in list(self.dfg.get_uses(old_var)): @@ -55,19 +54,24 @@ def update( if isinstance(op, IRVariable): self.dfg.add_use(op, inst) + old_outputs = inst.get_outputs() + if opcode in NO_OUTPUT_INSTRUCTIONS: - if inst.output is not 
None: + for output in old_outputs: assert new_output is None - assert len(uses := self.dfg.get_uses(inst.output)) == 0, (inst, uses) - self.dfg.remove_producing_instruction(inst.output) - inst.output = None + assert len(uses := self.dfg.get_uses(output)) == 0, (inst, uses) + self.dfg.remove_producing_instruction(output) + inst.set_outputs([]) else: # new_output is None is sentinel meaning "no change" - if new_output is not None and new_output != inst.output: - if inst.output is not None: - self.dfg.remove_producing_instruction(inst.output) + if new_output is not None: + old_primary = old_outputs[0] if len(old_outputs) > 0 else None + if old_primary is not None and old_primary != new_output: + self.dfg.remove_producing_instruction(old_primary) self.dfg.set_producing_instruction(new_output, inst) - inst.output = new_output + # multi-output instructions are not currently updated this way + assert len(old_outputs) <= 1 + inst.set_outputs([new_output]) inst.opcode = opcode inst.operands = new_operands @@ -89,10 +93,13 @@ def nop_multi(self, to_nop: Iterable[IRInstruction]): return # NOTE: this doesn't work for dfg cycles. inst = q.popleft() - if inst.output and len(self.dfg.get_uses(inst.output)) > 0: + # Check if ANY output has uses + outputs = inst.get_outputs() + has_uses = any(len(self.dfg.get_uses(output)) > 0 for output in outputs) + if has_uses: q.append(inst) - continue - self.nop(inst) + else: + self.nop(inst) # this should only happen if we try to delete a dfg cycle, cross # that bridge when we get to it. @@ -120,7 +127,7 @@ def add_before( var = inst.parent.parent.get_next_variable() operands = list(args) - new_inst = IRInstruction(opcode, operands, output=var) + new_inst = IRInstruction(opcode, operands, [var] if var is not None else None) inst.parent.insert_instruction(new_inst, index) for op in new_inst.operands: if isinstance(op, IRVariable): diff --git a/vyper/venom/passes/make_ssa.py b/vyper/venom/passes/make_ssa.py index 0ffd33d77b..3097af09b7 100644 --- a/vyper/venom/passes/make_ssa.py +++ b/vyper/venom/passes/make_ssa.py @@ -69,7 +69,7 @@ def _place_phi(self, var: IRVariable, basic_block: IRBasicBlock): args.append(bb.label) # type: ignore args.append(var) # type: ignore - basic_block.insert_instruction(IRInstruction("phi", args, var), 0) + basic_block.insert_instruction(IRInstruction("phi", args, [var]), 0) def latest_version_of(self, var: IRVariable) -> IRVariable: og_var = self.original_vars[var] @@ -103,21 +103,30 @@ def _rename_vars(self, basic_block: IRBasicBlock): inst.operands = new_ops - if inst.output is not None: - v_name = self.original_vars[inst.output].value + outputs = inst.get_outputs() + if len(outputs) == 0: + continue + + new_outputs: list[IRVariable] = [] + for output in outputs: + v_name = self.original_vars[output].value i = self.var_name_counters[v_name] self.var_name_stacks[v_name].append(i) self.var_name_counters[v_name] += 1 - inst.output = self.latest_version_of(inst.output) - outs.append(inst.output) + new_var = self.latest_version_of(output) + new_outputs.append(new_var) + outs.append(new_var) + + inst.set_outputs(new_outputs) for bb in self.cfg.cfg_out(basic_block): for inst in bb.instructions: if inst.opcode != "phi": continue - assert inst.output is not None, inst # phis should have output + # Ensure phi has exactly one output + _ = inst.output for i, op in enumerate(inst.operands): if op == basic_block.label: var = inst.operands[i + 1] @@ -140,8 +149,9 @@ def _remove_degenerate_phis(self, entry: IRBasicBlock): continue new_ops: 
list[IROperand] = [] + phi_out = inst.output for label, op in inst.phi_operands: - if op == inst.output: + if op == phi_out: continue new_ops.extend([label, op]) new_ops_len = len(new_ops) diff --git a/vyper/venom/passes/mem2var.py b/vyper/venom/passes/mem2var.py index bec17d0393..5b623c4e4f 100644 --- a/vyper/venom/passes/mem2var.py +++ b/vyper/venom/passes/mem2var.py @@ -2,7 +2,6 @@ from vyper.venom.analysis import CFGAnalysis, DFGAnalysis, LivenessAnalysis from vyper.venom.basicblock import IRInstruction, IRVariable from vyper.venom.function import IRFunction -from vyper.venom.ir_node_to_venom import ENABLE_NEW_CALL_CONV from vyper.venom.passes.base_pass import InstUpdater, IRPass @@ -70,17 +69,13 @@ def _process_palloca_var(self, dfg: DFGAnalysis, palloca_inst: IRInstruction, va # some value given to us by the calling convention fn = self.function - if ENABLE_NEW_CALL_CONV: - # it comes as a stack parameter. this (reifying with param based - # on alloca_id) is a bit kludgey, but we will live. - param = fn.get_param_by_id(alloca_id.value) - if param is None: - self.updater.update(palloca_inst, "mload", [ofst], new_output=var) - else: - self.updater.update(palloca_inst, "assign", [param.func_var], new_output=var) - else: - # otherwise, it comes from memory, convert to an mload. + # it comes as a stack parameter. this (reifying with param based + # on alloca_id) is a bit kludgey, but we will live. + param = fn.get_param_by_id(alloca_id.value) + if param is None: self.updater.update(palloca_inst, "mload", [ofst], new_output=var) + else: + self.updater.update(palloca_inst, "assign", [param.func_var], new_output=var) for inst in uses.copy(): if inst.opcode == "mstore": diff --git a/vyper/venom/passes/memmerging.py b/vyper/venom/passes/memmerging.py index 010a4ef335..7adc87fb7c 100644 --- a/vyper/venom/passes/memmerging.py +++ b/vyper/venom/passes/memmerging.py @@ -157,7 +157,6 @@ def _flush_copies( # if the load is used by any instructions besides the ones # we are removing, we can't delete it. (in the future this # may be handled by "remove unused effects" pass). - assert isinstance(inst.output, IRVariable) # help mypy uses = self.dfg.get_uses(inst.output) if not all(use in copy.insts for use in uses): continue @@ -270,7 +269,6 @@ def _barrier_for(copies: list[_Copy]): if len(copies) > 0: _barrier_for(copies) - assert inst.output is not None, inst self._loads[inst.output] = src_op.value elif inst.opcode == "mstore": @@ -415,8 +413,8 @@ def _merge_mstore_dload(self, bb: IRBasicBlock): dload = inst src = dload.operands[0] - assert dload.output is not None - uses = self.dfg.get_uses(dload.output) + dload_out = dload.output + uses = self.dfg.get_uses(dload_out) if len(uses) == 1: mstore: IRInstruction = uses.first() if mstore.opcode != "mstore": @@ -431,7 +429,7 @@ def _merge_mstore_dload(self, bb: IRBasicBlock): # that uses dload. 
If we would not restrain ourself to basic
             # block we would have to check if the mstore dominates all of
             # the other uses
-            uses_bb = dload.parent.get_uses().get(dload.output, OrderedSet())
+            uses_bb = dload.parent.get_uses().get(dload_out, OrderedSet())
             if len(uses_bb) == 0:
                 continue
@@ -444,10 +442,8 @@ def _merge_mstore_dload(self, bb: IRBasicBlock):
             var, dst = mstore.operands
-            if var != dload.output:
+            if var != dload_out:
                 continue
-
-            assert isinstance(var, IRVariable)  # help mypy
 
             new_var = bb.parent.get_next_variable()
             self.updater.add_before(mstore, "dloadbytes", [IRLiteral(32), src, dst])
@@ -455,7 +451,7 @@ def _merge_mstore_dload(self, bb: IRBasicBlock):
             mload = mstore  # clarity
 
-            self.updater.move_uses(dload.output, mload)
+            self.updater.move_uses(dload_out, mload)
 
             self.updater.nop(dload)
diff --git a/vyper/venom/passes/phi_elimination.py b/vyper/venom/passes/phi_elimination.py
index da608be059..82386d43f0 100644
--- a/vyper/venom/passes/phi_elimination.py
+++ b/vyper/venom/passes/phi_elimination.py
@@ -29,7 +29,6 @@ def _process_phi(self, inst: IRInstruction):
         src = srcs.pop()
         if src == inst:
             return
-        assert src.output is not None
         self.updater.mk_assign(inst, src.output)
 
     def _calculate_phi_origins(self):
diff --git a/vyper/venom/passes/remove_unused_variables.py b/vyper/venom/passes/remove_unused_variables.py
index 171e8805b7..64e59d3509 100644
--- a/vyper/venom/passes/remove_unused_variables.py
+++ b/vyper/venom/passes/remove_unused_variables.py
@@ -74,7 +74,8 @@ def msize_fence(self, inst):
         return self.instruction_ordering[inst] < self.get_last_msize(bb)
 
     def _process_instruction(self, inst):
-        if inst.output is None:
+        outputs = inst.get_outputs()
+        if len(outputs) == 0:
             return
         if inst.is_volatile or inst.is_bb_terminator:
             return
@@ -84,9 +85,11 @@ def _process_instruction(self, inst):
             self._blocked_by_msize.add(inst)
             return
 
-        uses = self.dfg.get_uses(inst.output)
-        if len(uses) > 0:
-            return
+        # Check if ANY output has uses
+        for output in outputs:
+            uses = self.dfg.get_uses(output)
+            if len(uses) > 0:
+                return
 
         for operand in uniq(inst.get_input_variables()):
             self.dfg.remove_use(operand, inst)
diff --git a/vyper/venom/passes/revert_to_assert.py b/vyper/venom/passes/revert_to_assert.py
index 69fcb10acd..067585bd87 100644
--- a/vyper/venom/passes/revert_to_assert.py
+++ b/vyper/venom/passes/revert_to_assert.py
@@ -35,7 +35,7 @@ def _rewrite_jnz(self, pred, revert_bb):
         cond, then_label, else_label = term.operands
         if then_label == revert_bb.label:
             new_cond = self.function.get_next_variable()
-            iszero_inst = IRInstruction("iszero", [cond], output=new_cond)
+            iszero_inst = IRInstruction("iszero", [cond], [new_cond])
             assert_inst = IRInstruction("assert", [iszero_inst.output])
             pred.insert_instruction(iszero_inst, index=-1)
             pred.insert_instruction(assert_inst, index=-1)
diff --git a/vyper/venom/passes/sccp/sccp.py b/vyper/venom/passes/sccp/sccp.py
index ee9580c450..5c13b1e563 100644
--- a/vyper/venom/passes/sccp/sccp.py
+++ b/vyper/venom/passes/sccp/sccp.py
@@ -171,10 +171,11 @@ def _visit_phi(self, inst: IRInstruction):
             in_vars.append(self._lookup_from_lattice(var))
         value = reduce(_meet, in_vars, LatticeEnum.TOP)  # type: ignore
 
-        assert inst.output in self.lattice, "unreachable"  # sanity
+        inst_out = inst.output
+        assert inst_out in self.lattice, "unreachable"  # sanity
 
-        if value != self._lookup_from_lattice(inst.output):
-            self._set_lattice(inst.output, value)
+        if value != self._lookup_from_lattice(inst_out):
+            self._set_lattice(inst_out, value)
             self._add_ssa_work_items(inst)
 
     def _visit_expr(self, inst: IRInstruction):
@@ -184,8 +185,9 @@ def _visit_expr(self, inst: IRInstruction):
         if self.remove_allocas:
             store_opcodes += ("alloca", "palloca", "calloca")
 
+        outputs = inst.get_outputs()
+
         if opcode in store_opcodes:
-            assert inst.output is not None, inst
             out = self._eval_from_lattice(inst.operands[0])
             self._set_lattice(inst.output, out)
             self._add_ssa_work_items(inst)
@@ -218,8 +220,9 @@ def _visit_expr(self, inst: IRInstruction):
         elif opcode in ARITHMETIC_OPS:
             self._eval(inst)
         else:
-            if inst.output is not None:
-                self._set_lattice(inst.output, LatticeEnum.BOTTOM)
+            if len(outputs) > 0:
+                for out_var in outputs:
+                    self._set_lattice(out_var, LatticeEnum.BOTTOM)
                 self._add_ssa_work_items(inst)
 
     def _eval(self, inst) -> LatticeItem:
@@ -230,11 +233,13 @@ def _eval(self, inst) -> LatticeItem:
         changed.
         """
 
+        out_var = inst.output
+
         def finalize(ret):
             # Update the lattice if the value changed
-            old_val = self.lattice.get(inst.output, LatticeEnum.TOP)
+            old_val = self.lattice.get(out_var, LatticeEnum.TOP)
             if old_val != ret:
-                self.lattice[inst.output] = ret
+                self.lattice[out_var] = ret
                 self._add_ssa_work_items(inst)
             return ret
 
@@ -269,8 +274,10 @@ def finalize(ret):
         return finalize(res)
 
     def _add_ssa_work_items(self, inst: IRInstruction):
-        for target_inst in self.dfg.get_uses(inst.output):  # type: ignore
-            self.work_list.append(SSAWorkListItem(target_inst))
+        outputs = inst.get_outputs()
+        for out in outputs:
+            for target_inst in self.dfg.get_uses(out):
+                self.work_list.append(SSAWorkListItem(target_inst))
 
     def _propagate_constants(self):
         """
diff --git a/vyper/venom/passes/single_use_expansion.py b/vyper/venom/passes/single_use_expansion.py
index de911a35cf..2cab3ff3d2 100644
--- a/vyper/venom/passes/single_use_expansion.py
+++ b/vyper/venom/passes/single_use_expansion.py
@@ -53,7 +53,7 @@ def _process_bb(self, bb):
                 continue
 
             var = self.function.get_next_variable()
-            to_insert = IRInstruction("assign", [op], var)
+            to_insert = IRInstruction("assign", [op], [var])
             bb.insert_instruction(to_insert, index=i)
             if len(inst.operands) > j:
                 inst.operands[j] = var
diff --git a/vyper/venom/venom_to_assembly.py b/vyper/venom/venom_to_assembly.py
index 26b9808d33..40719c9ce4 100644
--- a/vyper/venom/venom_to_assembly.py
+++ b/vyper/venom/venom_to_assembly.py
@@ -277,20 +277,19 @@ def _emit_input_operands(
             seen.add(op)
 
     def _prepare_stack_for_function(self, asm, fn: IRFunction, stack: StackModel):
-        last_param = None
+        last_param_inst = None
        for inst in fn.entry.instructions:
             if inst.opcode != "param":
                 # note: always well defined if the bb is terminated
                 next_liveness = self.liveness.live_vars_at(inst)
                 break
-            last_param = inst
+            last_param_inst = inst
-
-            assert inst.output is not None  # help mypy
             stack.push(inst.output)
 
         # no params (only applies for global entry function)
-        if last_param is None:
+        if last_param_inst is None:
             return
 
         to_pop: list[IRVariable] = []
@@ -301,7 +300,7 @@ def _prepare_stack_for_function(self, asm, fn: IRFunction, stack: StackModel):
 
         self.popmany(asm, to_pop, stack)
 
-        self._optimistic_swap(asm, last_param, next_liveness, stack)
+        self._optimistic_swap(asm, last_param_inst, next_liveness, stack)
 
     def popmany(self, asm, to_pop: Iterable[IRVariable], stack):
         to_pop = list(to_pop)
@@ -434,11 +433,15 @@ def _generate_evm_for_instruction(
             log_topic_count = inst.operands[0].value
             assert log_topic_count in [0, 1, 2, 3, 4], "Invalid topic count"
             operands = inst.operands[1:]
+        elif opcode == "ret":
+            # For ret with values, we only treat the return PC as an input operand
+            # The return values must remain on the stack and are not consumed here
+            operands = [inst.operands[-1]]
         else:
             operands = inst.operands
 
         if opcode == "phi":
-            ret = inst.get_outputs()[0]
+            ret = inst.output
             phis = list(inst.get_input_variables())
             depth = stack.get_phi_depth(phis)
             # collapse the arguments to the phi node in the stack.
@@ -458,7 +461,6 @@ def _generate_evm_for_instruction(
             ofst, label = inst.operands
             assert isinstance(label, IRLabel)  # help mypy
             assembly.extend(_ofst(_as_asm_symbol(label), ofst.value))
-            assert isinstance(inst.output, IROperand), "Offset must have output"
             stack.push(inst.output)
             return apply_line_numbers(inst, assembly)
@@ -502,10 +504,11 @@ def _generate_evm_for_instruction(
         # with the stack model containing the return value(s), so we fiddle
         # with the stack model beforehand.
 
-        # Step 4: Push instruction's return value to stack
+        # Step 4: Push instruction's return value(s) to stack
         stack.pop(len(operands))
-        if inst.output is not None:
-            stack.push(inst.output)
+        outputs = inst.get_outputs()
+        for out in outputs:
+            stack.push(out)
 
         # Step 5: Emit the EVM instruction(s)
         if opcode in _ONE_TO_ONE_INSTRUCTIONS:
@@ -603,12 +606,20 @@ def _generate_evm_for_instruction(
         else:
             raise Exception(f"Unknown opcode: {opcode}")
 
-        # Step 6: Emit instructions output operands (if any)
-        if inst.output is not None:
-            if inst.output not in next_liveness:
-                self.pop(assembly, stack)
-            else:
-                self._optimistic_swap(assembly, inst, next_liveness, stack)
+        # Step 6: Emit instruction output operands (if any)
+        if len(outputs) == 0:
+            return apply_line_numbers(inst, assembly)
+
+        dead_outputs = [out for out in outputs if out not in next_liveness]
+        self.popmany(assembly, dead_outputs, stack)
+
+        live_outputs = [out for out in outputs if out in next_liveness]
+        if len(live_outputs) == 0:
+            return apply_line_numbers(inst, assembly)
+
+        # Heuristic scheduling based on the next expected live var
+        # Use the top-most surviving output to schedule
+        self._optimistic_swap(assembly, inst, next_liveness, stack)
 
         return apply_line_numbers(inst, assembly)
@@ -629,8 +640,12 @@ def _optimistic_swap(self, assembly, inst, next_liveness, stack):
         next_scheduled = next_liveness.last()
         cost = 0
-        if not self.dfg.are_equivalent(inst.output, next_scheduled):
-            cost = self.swap_op(assembly, stack, next_scheduled)
+        # Use last output (top-of-stack) when available, else the single output
+        inst_outputs = inst.get_outputs()
+        if len(inst_outputs) > 0:
+            current_top_out = inst_outputs[-1]
+            if not self.dfg.are_equivalent(current_top_out, next_scheduled):
+                cost = self.swap_op(assembly, stack, next_scheduled)
 
         if DEBUG_SHOW_COST and cost != 0:
             print("ENTER", inst, file=sys.stderr)